Skip to content

Commit 360a652

Browse files
committed
refactor: change IsolationLevel to Enum
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
1 parent 744c4dc commit 360a652

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

tests/sqlalchemy/test_dialect.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from typing import List, Any, Dict
2+
from unittest import mock
23

34
import pytest
45
from assertpy import assert_that
56
from sqlalchemy.engine import make_url
67
from sqlalchemy.engine.url import URL
78

89
from trino.auth import BasicAuthentication
10+
from trino.dbapi import Connection
911
from trino.sqlalchemy.dialect import TrinoDialect
12+
from trino.transaction import IsolationLevel
1013

1114

1215
class TestTrinoDialect:
@@ -32,3 +35,16 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_
3235

3336
assert_that(actual_args).is_equal_to(expected_args)
3437
assert_that(actual_kwargs).is_equal_to(expected_kwargs)
38+
39+
def test_get_default_isolation_level(self):
40+
isolation_level = self.dialect.get_default_isolation_level(mock.Mock())
41+
assert_that(isolation_level).is_equal_to('AUTOCOMMIT')
42+
43+
def test_isolation_level(self):
44+
dbapi_conn = Connection(host="localhost")
45+
46+
self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE")
47+
assert_that(dbapi_conn._isolation_level).is_equal_to(IsolationLevel.SERIALIZABLE)
48+
49+
isolation_level = self.dialect.get_isolation_level(dbapi_conn)
50+
assert_that(isolation_level).is_equal_to("SERIALIZABLE")

trino/sqlalchemy/dialect.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from trino import dbapi as trino_dbapi
2121
from trino.auth import BasicAuthentication
2222
from trino.dbapi import Cursor
23-
from . import compiler, datatype, error
23+
from trino.sqlalchemy import compiler, datatype, error
2424

2525

2626
class TrinoDialect(DefaultDialect):
@@ -291,16 +291,14 @@ def do_commit_twophase(self, connection: Connection, xid: str,
291291
def do_recover_twophase(self, connection: Connection) -> None:
292292
pass
293293

294-
def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level) -> None:
295-
dbapi_conn._isolation_level = getattr(trino_dbapi.IsolationLevel, level)
294+
def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level: str) -> None:
295+
dbapi_conn._isolation_level = trino_dbapi.IsolationLevel[level]
296296

297297
def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str:
298-
level_names = ['AUTOCOMMIT',
299-
'READ_UNCOMMITTED',
300-
'READ_COMMITTED',
301-
'REPEATABLE_READ',
302-
'SERIALIZABLE']
303-
return level_names[dbapi_conn.isolation_level]
298+
return dbapi_conn.isolation_level.name
299+
300+
def get_default_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str:
301+
return trino_dbapi.IsolationLevel.AUTOCOMMIT.name
304302

305303
def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str:
306304
table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name

trino/transaction.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
from enum import Enum, unique
1213
from typing import Iterable
1314

14-
from trino import constants
1515
import trino.client
1616
import trino.exceptions
1717
import trino.logging
18-
18+
from trino import constants
1919

2020
logger = trino.logging.get_logger(__name__)
2121

22-
2322
NO_TRANSACTION = "NONE"
2423
START_TRANSACTION = "START TRANSACTION"
2524
ROLLBACK = "ROLLBACK"
2625
COMMIT = "COMMIT"
2726

2827

29-
class IsolationLevel(object):
28+
@unique
29+
class IsolationLevel(Enum):
3030
AUTOCOMMIT = 0
3131
READ_UNCOMMITTED = 1
3232
READ_COMMITTED = 2
@@ -35,16 +35,16 @@ class IsolationLevel(object):
3535

3636
@classmethod
3737
def levels(cls) -> Iterable[str]:
38-
return {k for k, v in cls.__dict__.items() if not k.startswith("_") and isinstance(v, int)}
38+
return {isolation_level.name for isolation_level in IsolationLevel}
3939

4040
@classmethod
4141
def values(cls) -> Iterable[int]:
42-
return {getattr(cls, level) for level in cls.levels()}
42+
return {isolation_level.value for isolation_level in IsolationLevel}
4343

4444
@classmethod
4545
def check(cls, level: int) -> int:
4646
if level not in cls.values():
47-
raise ValueError("invalid isolation level {}".format(level))
47+
raise ValueError(f"invalid isolation level {level}")
4848
return level
4949

5050

@@ -60,9 +60,7 @@ def id(self):
6060
def begin(self):
6161
response = self._request.post(START_TRANSACTION)
6262
if not response.ok:
63-
raise trino.exceptions.DatabaseError(
64-
"failed to start transaction: {}".format(response.status_code)
65-
)
63+
raise trino.exceptions.DatabaseError(f"failed to start transaction: {response.status_code}")
6664
transaction_id = response.headers.get(constants.HEADER_STARTED_TRANSACTION)
6765
if transaction_id and transaction_id != NO_TRANSACTION:
6866
self._id = response.headers[constants.HEADER_STARTED_TRANSACTION]
@@ -74,16 +72,14 @@ def begin(self):
7472
self._id = response.headers[constants.HEADER_STARTED_TRANSACTION]
7573
status = self._request.process(response)
7674
self._request.transaction_id = self._id
77-
logger.info("transaction started: " + self._id)
75+
logger.info("transaction started: %s", self._id)
7876

7977
def commit(self):
8078
query = trino.client.TrinoQuery(self._request, COMMIT)
8179
try:
8280
list(query.execute())
8381
except Exception as err:
84-
raise trino.exceptions.DatabaseError(
85-
"failed to commit transaction {}: {}".format(self._id, err)
86-
)
82+
raise trino.exceptions.DatabaseError(f"failed to commit transaction {self._id}") from err
8783
self._id = NO_TRANSACTION
8884
self._request.transaction_id = self._id
8985

@@ -92,8 +88,6 @@ def rollback(self):
9288
try:
9389
list(query.execute())
9490
except Exception as err:
95-
raise trino.exceptions.DatabaseError(
96-
"failed to rollback transaction {}: {}".format(self._id, err)
97-
)
91+
raise trino.exceptions.DatabaseError(f"failed to rollback transaction {self._id}") from err
9892
self._id = NO_TRANSACTION
9993
self._request.transaction_id = self._id

0 commit comments

Comments
 (0)