Skip to content

Commit b3852bf

Browse files
fix: Adopt connection pooling for HBase (#3793)
1 parent 175d796 commit b3852bf

File tree

3 files changed

+107
-81
lines changed

3 files changed

+107
-81
lines changed

sdk/python/feast/feature_store.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,11 @@ def _list_feature_views(
287287
for fv in self._registry.list_feature_views(
288288
self.project, allow_cache=allow_cache
289289
):
290-
if hide_dummy_entity and fv.entities[0] == DUMMY_ENTITY_NAME:
290+
if (
291+
hide_dummy_entity
292+
and fv.entities
293+
and fv.entities[0] == DUMMY_ENTITY_NAME
294+
):
291295
fv.entities = []
292296
fv.entity_columns = []
293297
feature_views.append(fv)

sdk/python/feast/infra/online_stores/contrib/hbase_online_store/hbase.py

+24-31
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
from datetime import datetime
44
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
55

6-
from happybase import Connection
6+
from happybase import ConnectionPool
7+
from happybase.connection import DEFAULT_PROTOCOL, DEFAULT_TRANSPORT
8+
from pydantic import StrictStr
79
from pydantic.typing import Literal
810

911
from feast import Entity
1012
from feast.feature_view import FeatureView
1113
from feast.infra.online_stores.helpers import compute_entity_id
1214
from feast.infra.online_stores.online_store import OnlineStore
13-
from feast.infra.utils.hbase_utils import HbaseConstants, HbaseUtils
15+
from feast.infra.utils.hbase_utils import HBaseConnector, HbaseConstants
1416
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
1517
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
1618
from feast.repo_config import FeastConfigBaseModel, RepoConfig
@@ -23,35 +25,20 @@ class HbaseOnlineStoreConfig(FeastConfigBaseModel):
2325
type: Literal["hbase"] = "hbase"
2426
"""Online store type selector"""
2527

26-
host: str
28+
host: StrictStr
2729
"""Hostname of Hbase Thrift server"""
2830

29-
port: str
31+
port: StrictStr
3032
"""Port in which Hbase Thrift server is running"""
3133

34+
connection_pool_size: int = 4
35+
"""Number of connections to Hbase Thrift server to keep in the connection pool"""
3236

33-
class HbaseConnection:
34-
"""
35-
Hbase connecttion to connect to hbase.
36-
37-
Attributes:
38-
store_config: Online store config for Hbase store.
39-
"""
37+
protocol: StrictStr = DEFAULT_PROTOCOL
38+
"""Protocol used to communicate with Hbase Thrift server"""
4039

41-
def __init__(self, store_config: HbaseOnlineStoreConfig):
42-
self._store_config = store_config
43-
self._real_conn = Connection(
44-
host=store_config.host, port=int(store_config.port)
45-
)
46-
47-
@property
48-
def real_conn(self) -> Connection:
49-
"""Stores the real happybase Connection to connect to hbase."""
50-
return self._real_conn
51-
52-
def close(self) -> None:
53-
"""Close the happybase connection."""
54-
self.real_conn.close()
40+
transport: StrictStr = DEFAULT_TRANSPORT
41+
"""Transport used to communicate with Hbase Thrift server"""
5542

5643

5744
class HbaseOnlineStore(OnlineStore):
@@ -62,7 +49,7 @@ class HbaseOnlineStore(OnlineStore):
6249
_conn: Happybase Connection to connect to hbase thrift server.
6350
"""
6451

65-
_conn: Connection = None
52+
_conn: ConnectionPool = None
6653

6754
def _get_conn(self, config: RepoConfig):
6855
"""
@@ -76,7 +63,13 @@ def _get_conn(self, config: RepoConfig):
7663
assert isinstance(store_config, HbaseOnlineStoreConfig)
7764

7865
if not self._conn:
79-
self._conn = Connection(host=store_config.host, port=int(store_config.port))
66+
self._conn = ConnectionPool(
67+
host=store_config.host,
68+
port=int(store_config.port),
69+
size=int(store_config.connection_pool_size),
70+
protocol=store_config.protocol,
71+
transport=store_config.transport,
72+
)
8073
return self._conn
8174

8275
@log_exceptions_and_usage(online_store="hbase")
@@ -102,7 +95,7 @@ def online_write_batch(
10295
the online store. Can be used to display progress.
10396
"""
10497

105-
hbase = HbaseUtils(self._get_conn(config))
98+
hbase = HBaseConnector(self._get_conn(config))
10699
project = config.project
107100
table_name = self._table_id(project, table)
108101

@@ -154,7 +147,7 @@ def online_read(
154147
entity_keys: a list of entity keys that should be read from the FeatureStore.
155148
requested_features: a list of requested feature names.
156149
"""
157-
hbase = HbaseUtils(self._get_conn(config))
150+
hbase = HBaseConnector(self._get_conn(config))
158151
project = config.project
159152
table_name = self._table_id(project, table)
160153

@@ -206,7 +199,7 @@ def update(
206199
tables_to_delete: Tables to delete from the Hbase Online Store.
207200
tables_to_keep: Tables to keep in the Hbase Online Store.
208201
"""
209-
hbase = HbaseUtils(self._get_conn(config))
202+
hbase = HBaseConnector(self._get_conn(config))
210203
project = config.project
211204

212205
# We don't create any special state for the entites in this implementation.
@@ -232,7 +225,7 @@ def teardown(
232225
config: The RepoConfig for the current FeatureStore.
233226
tables: Tables to delete from the feature repo.
234227
"""
235-
hbase = HbaseUtils(self._get_conn(config))
228+
hbase = HBaseConnector(self._get_conn(config))
236229
project = config.project
237230

238231
for table in tables:

sdk/python/feast/infra/utils/hbase_utils.py

+78-49
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from typing import List
22

3-
from happybase import Connection
4-
5-
from feast.infra.key_encoding_utils import serialize_entity_key
6-
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
3+
from happybase import ConnectionPool
74

85

96
class HbaseConstants:
@@ -28,7 +25,7 @@ def get_col_from_feature(feature):
2825
return HbaseConstants.DEFAULT_COLUMN_FAMILY + ":" + feature
2926

3027

31-
class HbaseUtils:
28+
class HBaseConnector:
3229
"""
3330
Utils class to manage different Hbase operations.
3431
@@ -40,14 +37,22 @@ class HbaseUtils:
4037
"""
4138

4239
def __init__(
43-
self, conn: Connection = None, host: str = None, port: int = None, timeout=None
40+
self,
41+
pool: ConnectionPool = None,
42+
host: str = None,
43+
port: int = None,
44+
connection_pool_size: int = 4,
4445
):
45-
if conn is None:
46+
if pool is None:
4647
self.host = host
4748
self.port = port
48-
self.conn = Connection(host=host, port=port, timeout=timeout)
49+
self.pool = ConnectionPool(
50+
host=host,
51+
port=port,
52+
size=connection_pool_size,
53+
)
4954
else:
50-
self.conn = conn
55+
self.pool = pool
5156

5257
def create_table(self, table_name: str, colm_family: List[str]):
5358
"""
@@ -60,7 +65,9 @@ def create_table(self, table_name: str, colm_family: List[str]):
6065
cf_dict: dict = {}
6166
for cf in colm_family:
6267
cf_dict[cf] = dict()
63-
return self.conn.create_table(table_name, cf_dict)
68+
69+
with self.pool.connection() as conn:
70+
return conn.create_table(table_name, cf_dict)
6471

6572
def create_table_with_default_cf(self, table_name: str):
6673
"""
@@ -69,7 +76,8 @@ def create_table_with_default_cf(self, table_name: str):
6976
Arguments:
7077
table_name: Name of the Hbase table.
7178
"""
72-
return self.conn.create_table(table_name, {"default": dict()})
79+
with self.pool.connection() as conn:
80+
return conn.create_table(table_name, {"default": dict()})
7381

7482
def check_if_table_exist(self, table_name: str):
7583
"""
@@ -78,16 +86,18 @@ def check_if_table_exist(self, table_name: str):
7886
Arguments:
7987
table_name: Name of the Hbase table.
8088
"""
81-
return bytes(table_name, "utf-8") in self.conn.tables()
89+
with self.pool.connection() as conn:
90+
return bytes(table_name, "utf-8") in conn.tables()
8291

8392
def batch(self, table_name: str):
8493
"""
85-
Returns a 'Batch' instance that can be used for mass data manipulation in the hbase table.
94+
Returns a "Batch" instance that can be used for mass data manipulation in the hbase table.
8695
8796
Arguments:
8897
table_name: Name of the Hbase table.
8998
"""
90-
return self.conn.table(table_name).batch()
99+
with self.pool.connection() as conn:
100+
return conn.table(table_name).batch()
91101

92102
def put(self, table_name: str, row_key: str, data: dict):
93103
"""
@@ -98,8 +108,9 @@ def put(self, table_name: str, row_key: str, data: dict):
98108
row_key: Row key of the row to be inserted to hbase table.
99109
data: Mapping of column family name:column name to column values
100110
"""
101-
table = self.conn.table(table_name)
102-
table.put(row_key, data)
111+
with self.pool.connection() as conn:
112+
table = conn.table(table_name)
113+
table.put(row_key, data)
103114

104115
def row(
105116
self,
@@ -119,8 +130,9 @@ def row(
119130
timestamp: timestamp specifies the maximum version the cells can have.
120131
include_timestamp: specifies if (column, timestamp) to be return instead of only column.
121132
"""
122-
table = self.conn.table(table_name)
123-
return table.row(row_key, columns, timestamp, include_timestamp)
133+
with self.pool.connection() as conn:
134+
table = conn.table(table_name)
135+
return table.row(row_key, columns, timestamp, include_timestamp)
124136

125137
def rows(
126138
self,
@@ -140,52 +152,69 @@ def rows(
140152
timestamp: timestamp specifies the maximum version the cells can have.
141153
include_timestamp: specifies if (column, timestamp) to be return instead of only column.
142154
"""
143-
table = self.conn.table(table_name)
144-
return table.rows(row_keys, columns, timestamp, include_timestamp)
155+
with self.pool.connection() as conn:
156+
table = conn.table(table_name)
157+
return table.rows(row_keys, columns, timestamp, include_timestamp)
145158

146159
def print_table(self, table_name):
147160
"""Prints the table scanning all the rows of the hbase table."""
148-
table = self.conn.table(table_name)
149-
scan_data = table.scan()
150-
for row_key, cols in scan_data:
151-
print(row_key.decode("utf-8"), cols)
161+
with self.pool.connection() as conn:
162+
table = conn.table(table_name)
163+
scan_data = table.scan()
164+
for row_key, cols in scan_data:
165+
print(row_key.decode("utf-8"), cols)
152166

153167
def delete_table(self, table: str):
154168
"""Deletes the hbase table given the table name."""
155169
if self.check_if_table_exist(table):
156-
self.conn.delete_table(table, disable=True)
170+
with self.pool.connection() as conn:
171+
conn.delete_table(table, disable=True)
157172

158173
def close_conn(self):
159174
"""Closes the happybase connection."""
160-
self.conn.close()
175+
with self.pool.connection() as conn:
176+
conn.close()
161177

162178

163179
def main():
180+
from feast.infra.key_encoding_utils import serialize_entity_key
181+
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
164182
from feast.protos.feast.types.Value_pb2 import Value
165183

166-
connection = Connection(host="localhost", port=9090)
167-
table = connection.table("test_hbase_driver_hourly_stats")
168-
row_keys = [
169-
serialize_entity_key(
170-
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]),
171-
entity_key_serialization_version=2,
172-
).hex(),
173-
serialize_entity_key(
174-
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]),
175-
entity_key_serialization_version=2,
176-
).hex(),
177-
serialize_entity_key(
178-
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]),
179-
entity_key_serialization_version=2,
180-
).hex(),
181-
]
182-
rows = table.rows(row_keys)
183-
184-
for row_key, row in rows:
185-
for key, value in row.items():
186-
col_name = bytes.decode(key, "utf-8").split(":")[1]
187-
print(col_name, value)
188-
print()
184+
pool = ConnectionPool(
185+
host="localhost",
186+
port=9090,
187+
size=2,
188+
)
189+
with pool.connection() as connection:
190+
table = connection.table("test_hbase_driver_hourly_stats")
191+
row_keys = [
192+
serialize_entity_key(
193+
EntityKey(
194+
join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]
195+
),
196+
entity_key_serialization_version=2,
197+
).hex(),
198+
serialize_entity_key(
199+
EntityKey(
200+
join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]
201+
),
202+
entity_key_serialization_version=2,
203+
).hex(),
204+
serialize_entity_key(
205+
EntityKey(
206+
join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]
207+
),
208+
entity_key_serialization_version=2,
209+
).hex(),
210+
]
211+
rows = table.rows(row_keys)
212+
213+
for _, row in rows:
214+
for key, value in row.items():
215+
col_name = bytes.decode(key, "utf-8").split(":")[1]
216+
print(col_name, value)
217+
print()
189218

190219

191220
if __name__ == "__main__":

0 commit comments

Comments
 (0)