Skip to content

Commit c5d19b5

Browse files
committed
feat: support sqlalchemy
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
1 parent c8144de commit c5d19b5

18 files changed

+1196
-55
lines changed

integration_tests/__init__.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.

setup.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,23 @@
1414

1515
import ast
1616
import re
17-
from setuptools import setup
1817
import textwrap
1918

19+
from setuptools import setup
2020

2121
_version_re = re.compile(r"__version__\s+=\s+(.*)")
2222

23-
2423
with open("trino/__init__.py", "rb") as f:
2524
trino_version = _version_re.search(f.read().decode("utf-8"))
2625
assert trino_version is not None
2726
version = str(ast.literal_eval(trino_version.group(1)))
2827

29-
3028
kerberos_require = ["requests_kerberos"]
29+
sqlalchemy_require = ["sqlalchemy~=1.3"]
3130

32-
all_require = kerberos_require + []
31+
all_require = kerberos_require + sqlalchemy_require
3332

34-
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click"]
33+
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click", "assertpy"]
3534

3635
setup(
3736
name="trino",
@@ -44,19 +43,17 @@
4443
description="Client for the Trino distributed SQL Engine",
4544
long_description=textwrap.dedent(
4645
"""
47-
Client for Trino (https://trino.io), a distributed SQL engine for
48-
interactive and batch big data processing. Provides a low-level client and
49-
a DBAPI 2.0 implementation.
50-
"""
46+
Client for Trino (https://trino.io), a distributed SQL engine for
47+
interactive and batch big data processing. Provides a low-level client and
48+
a DBAPI 2.0 implementation.
49+
"""
5150
),
5251
license="Apache 2.0",
5352
classifiers=[
5453
"Development Status :: 4 - Beta",
5554
"Intended Audience :: Developers",
5655
"License :: OSI Approved :: Apache Software License",
57-
"Operating System :: MacOS :: MacOS X",
58-
"Operating System :: POSIX",
59-
"Operating System :: Microsoft :: Windows",
56+
"Operating System :: OS Independent",
6057
"Programming Language :: Python",
6158
"Programming Language :: Python :: 3",
6259
"Programming Language :: Python :: 3.6",
@@ -66,13 +63,20 @@
6663
"Programming Language :: Python :: 3.10",
6764
"Programming Language :: Python :: Implementation :: CPython",
6865
"Programming Language :: Python :: Implementation :: PyPy",
66+
"Topic :: Database",
6967
"Topic :: Database :: Front-Ends",
7068
],
7169
python_requires='>=3.6',
7270
install_requires=["requests"],
7371
extras_require={
7472
"all": all_require,
7573
"kerberos": kerberos_require,
74+
"sqlalchemy": sqlalchemy_require,
7675
"tests": tests_require,
7776
},
77+
entry_points={
78+
"sqlalchemy.dialects": [
79+
"trino = trino.sqlalchemy.dialect:TrinoDialect",
80+
]
81+
},
7882
)

tests/__init__.py

Whitespace-only changes.

tests/integration/test_dbapi_integration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616
import pytz
1717

18-
import trino
18+
import trino.dbapi
1919
from trino.exceptions import TrinoQueryError
2020
from trino.transaction import IsolationLevel
2121

tests/sqlalchemy/__init__.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.

tests/sqlalchemy/conftest.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
from assertpy import add_extension, assert_that
13+
from sqlalchemy.sql.sqltypes import ARRAY
14+
15+
from trino.sqlalchemy.datatype import MAP, ROW, SQLType
16+
17+
18+
def assert_sqltype(this: SQLType, that: SQLType):
19+
if isinstance(this, type):
20+
this = this()
21+
if isinstance(that, type):
22+
that = that()
23+
assert_that(type(this)).is_same_as(type(that))
24+
if isinstance(this, ARRAY):
25+
assert_sqltype(this.item_type, that.item_type)
26+
if this.dimensions is None or this.dimensions == 1:
27+
# ARRAY(dimensions=None) == ARRAY(dimensions=1)
28+
assert_that(that.dimensions).is_in(None, 1)
29+
else:
30+
assert_that(this.dimensions).is_equal_to(this.dimensions)
31+
elif isinstance(this, MAP):
32+
assert_sqltype(this.key_type, that.key_type)
33+
assert_sqltype(this.value_type, that.value_type)
34+
elif isinstance(this, ROW):
35+
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
36+
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
37+
assert_that(this_attr[0]).is_equal_to(that_attr[0])
38+
assert_sqltype(this_attr[1], that_attr[1])
39+
else:
40+
assert_that(str(this)).is_equal_to(str(that))
41+
42+
43+
@add_extension
44+
def is_sqltype(self, that):
45+
this = self.val
46+
assert_sqltype(this, that)
+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
import pytest
13+
from assertpy import assert_that
14+
from sqlalchemy.sql.sqltypes import (
15+
CHAR, VARCHAR,
16+
ARRAY,
17+
INTEGER, DECIMAL,
18+
DATE, TIME, TIMESTAMP
19+
)
20+
from sqlalchemy.sql.type_api import TypeEngine
21+
22+
from trino.sqlalchemy import datatype
23+
from trino.sqlalchemy.datatype import MAP, ROW
24+
25+
26+
@pytest.mark.parametrize(
27+
'type_str, sql_type',
28+
datatype._type_map.items(),
29+
ids=datatype._type_map.keys()
30+
)
31+
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
32+
actual_type = datatype.parse_sqltype(type_str)
33+
if not isinstance(actual_type, type):
34+
actual_type = type(actual_type)
35+
assert_that(actual_type).is_equal_to(sql_type)
36+
37+
38+
parse_cases_testcases = {
39+
'char(10)': CHAR(10),
40+
'Char(10)': CHAR(10),
41+
'char': CHAR(),
42+
'cHaR': CHAR(),
43+
'VARCHAR(10)': VARCHAR(10),
44+
'varCHAR(10)': VARCHAR(10),
45+
'VARchar(10)': VARCHAR(10),
46+
'VARCHAR': VARCHAR(),
47+
'VaRchAr': VARCHAR(),
48+
}
49+
50+
51+
@pytest.mark.parametrize(
52+
'type_str, sql_type',
53+
parse_cases_testcases.items(),
54+
ids=parse_cases_testcases.keys()
55+
)
56+
def test_parse_cases(type_str: str, sql_type: TypeEngine):
57+
actual_type = datatype.parse_sqltype(type_str)
58+
assert_that(actual_type).is_sqltype(sql_type)
59+
60+
61+
parse_type_options_testcases = {
62+
'CHAR(10)': CHAR(10),
63+
'VARCHAR(10)': VARCHAR(10),
64+
'DECIMAL(20)': DECIMAL(20),
65+
'DECIMAL(20, 3)': DECIMAL(20, 3),
66+
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
67+
}
68+
69+
70+
@pytest.mark.parametrize(
71+
'type_str, sql_type',
72+
parse_type_options_testcases.items(),
73+
ids=parse_type_options_testcases.keys()
74+
)
75+
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
76+
actual_type = datatype.parse_sqltype(type_str)
77+
assert_that(actual_type).is_sqltype(sql_type)
78+
79+
80+
parse_array_testcases = {
81+
'array(integer)': ARRAY(INTEGER()),
82+
'array(varchar(10))': ARRAY(VARCHAR(10)),
83+
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
84+
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
85+
'array(map(char, integer))': ARRAY(MAP(CHAR(), INTEGER())),
86+
'array(row(a integer, b varchar))': ARRAY(ROW([("a", INTEGER()), ("b", VARCHAR())])),
87+
}
88+
89+
90+
@pytest.mark.parametrize(
91+
'type_str, sql_type',
92+
parse_array_testcases.items(),
93+
ids=parse_array_testcases.keys()
94+
)
95+
def test_parse_array(type_str: str, sql_type: ARRAY):
96+
actual_type = datatype.parse_sqltype(type_str)
97+
assert_that(actual_type).is_sqltype(sql_type)
98+
99+
100+
parse_map_testcases = {
101+
'map(char, integer)': MAP(CHAR(), INTEGER()),
102+
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
103+
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
104+
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
105+
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
106+
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
107+
}
108+
109+
110+
@pytest.mark.parametrize(
111+
'type_str, sql_type',
112+
parse_map_testcases.items(),
113+
ids=parse_map_testcases.keys()
114+
)
115+
def test_parse_map(type_str: str, sql_type: ARRAY):
116+
actual_type = datatype.parse_sqltype(type_str)
117+
assert_that(actual_type).is_sqltype(sql_type)
118+
119+
120+
parse_row_testcases = {
121+
'row(a integer, b varchar)':
122+
ROW(attr_types=[
123+
("a", INTEGER()),
124+
("b", VARCHAR()),
125+
]),
126+
'row(a varchar(20), b decimal(20,3))':
127+
ROW(attr_types=[
128+
("a", VARCHAR(20)),
129+
("b", DECIMAL(20, 3)),
130+
]),
131+
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
132+
ROW(attr_types=[
133+
("x", ARRAY(VARCHAR(10))),
134+
("y", ARRAY(VARCHAR(10), dimensions=2)),
135+
("z", DECIMAL(20, 3)),
136+
]),
137+
'row(min timestamp(6) with time zone, max timestamp(6) with time zone)':
138+
ROW(attr_types=[
139+
("min", TIMESTAMP(timezone=True)),
140+
("max", TIMESTAMP(timezone=True)),
141+
]),
142+
'row("first name" varchar, "last name" varchar)':
143+
ROW(attr_types=[
144+
("first name", VARCHAR()),
145+
("last name", VARCHAR()),
146+
]),
147+
'row("foo,bar" varchar, "foo(bar)" varchar, "foo\\"bar" varchar)':
148+
ROW(attr_types=[
149+
(r'foo,bar', VARCHAR()),
150+
(r'foo(bar)', VARCHAR()),
151+
(r'foo"bar', VARCHAR()),
152+
]),
153+
}
154+
155+
156+
@pytest.mark.parametrize(
157+
'type_str, sql_type',
158+
parse_row_testcases.items(),
159+
ids=parse_row_testcases.keys()
160+
)
161+
def test_parse_row(type_str: str, sql_type: ARRAY):
162+
actual_type = datatype.parse_sqltype(type_str)
163+
assert_that(actual_type).is_sqltype(sql_type)
164+
165+
166+
parse_datetime_testcases = {
167+
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
168+
'date': DATE(),
169+
'time': TIME(),
170+
'time with time zone': TIME(timezone=True),
171+
'timestamp': TIMESTAMP(),
172+
'timestamp with time zone': TIMESTAMP(timezone=True),
173+
}
174+
175+
176+
@pytest.mark.parametrize(
177+
'type_str, sql_type',
178+
parse_datetime_testcases.items(),
179+
ids=parse_datetime_testcases.keys()
180+
)
181+
def test_parse_datetime(type_str: str, sql_type: ARRAY):
182+
actual_type = datatype.parse_sqltype(type_str)
183+
assert_that(actual_type).is_sqltype(sql_type)

0 commit comments

Comments
 (0)