Skip to content

Commit eb1ed24

Browse files
committed
minimal changes for sqlalchemy 2.0 support
1 parent 4367cc5 commit eb1ed24

7 files changed

+220
-97
lines changed

README.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
7474
engine = create_engine('trino://localhost:8080/hive/default')
7575
# Hive
7676
engine = create_engine('hive://localhost:10000/default')
77+
78+
# SQLAlchemy < 2.0
7779
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
7880
print select([func.count('*')], from_obj=logs).scalar()
7981
@@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
8284
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
8385
print select([func.count('*')], from_obj=logs).scalar()
8486
87+
# SQLAlchemy >= 2.0
88+
metadata_obj = MetaData()
89+
books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String))
90+
metadata_obj.create_all(engine)
91+
inspector = inspect(engine)
92+
inspector.get_columns('books')
93+
94+
with engine.connect() as con:
95+
data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" },
96+
{ "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }]
97+
con.execute(books.insert(), data[0])
98+
result = con.execute(text("select * from books"))
99+
print(result.fetchall())
100+
85101
Note: query generation functionality is not exhaustive or fully tested, but there should be no
86102
problem with raw SQL.
87103

pyhive/sqlalchemy_hive.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,22 @@
1313

1414
import re
1515
from sqlalchemy import exc
16-
from sqlalchemy import processors
16+
from sqlalchemy.sql import text
17+
try:
18+
from sqlalchemy import processors
19+
except ImportError:
20+
# Required for SQLAlchemy>=2.0
21+
from sqlalchemy.engine import processors
1722
from sqlalchemy import types
1823
from sqlalchemy import util
1924
# TODO shouldn't use mysql type
20-
from sqlalchemy.databases import mysql
25+
try:
26+
from sqlalchemy.databases import mysql
27+
mysql_tinyinteger = mysql.MSTinyInteger
28+
except ImportError:
29+
# Required for SQLAlchemy>2.0
30+
from sqlalchemy.dialects import mysql
31+
mysql_tinyinteger = mysql.base.MSTinyInteger
2132
from sqlalchemy.engine import default
2233
from sqlalchemy.sql import compiler
2334
from sqlalchemy.sql.compiler import SQLCompiler
@@ -121,7 +132,7 @@ def __init__(self, dialect):
121132

122133
_type_map = {
123134
'boolean': types.Boolean,
124-
'tinyint': mysql.MSTinyInteger,
135+
'tinyint': mysql_tinyinteger,
125136
'smallint': types.SmallInteger,
126137
'int': types.Integer,
127138
'bigint': types.BigInteger,
@@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect):
247258
supports_multivalues_insert = True
248259
type_compiler = HiveTypeCompiler
249260
supports_sane_rowcount = False
261+
supports_statement_cache = False
250262

251263
@classmethod
252264
def dbapi(cls):
253265
return hive
266+
267+
@classmethod
268+
def import_dbapi(cls):
269+
return hive
254270

255271
def create_connect_args(self, url):
256272
kwargs = {
@@ -265,7 +281,7 @@ def create_connect_args(self, url):
265281

266282
def get_schema_names(self, connection, **kw):
267283
# Equivalent to SHOW DATABASES
268-
return [row[0] for row in connection.execute('SHOW SCHEMAS')]
284+
return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]
269285

270286
def get_view_names(self, connection, schema=None, **kw):
271287
# Hive does not provide functionality to query tableType
@@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema):
280296
# Using DESCRIBE works but is uglier.
281297
try:
282298
# This needs the table name to be unescaped (no backticks).
283-
rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall()
299+
rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
284300
except exc.OperationalError as e:
285301
# Does the table exist?
286302
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
@@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema):
296312
raise exc.NoSuchTableError(full_table)
297313
return rows
298314

299-
def has_table(self, connection, table_name, schema=None):
315+
def has_table(self, connection, table_name, schema=None, **kw):
300316
try:
301317
self._get_table_columns(connection, table_name, schema)
302318
return True
@@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw):
361377
query = 'SHOW TABLES'
362378
if schema:
363379
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
364-
return [row[0] for row in connection.execute(query)]
380+
return [row[0] for row in connection.execute(text(query))]
365381

366382
def do_rollback(self, dbapi_connection):
367383
# No transactions for Hive

pyhive/sqlalchemy_presto.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,27 @@
99
from __future__ import unicode_literals
1010

1111
import re
12+
import sqlalchemy
1213
from sqlalchemy import exc
1314
from sqlalchemy import types
1415
from sqlalchemy import util
1516
# TODO shouldn't use mysql type
16-
from sqlalchemy.databases import mysql
17+
from sqlalchemy.sql import text
18+
try:
19+
from sqlalchemy.databases import mysql
20+
mysql_tinyinteger = mysql.MSTinyInteger
21+
except ImportError:
22+
# Required for SQLAlchemy>=2.0
23+
from sqlalchemy.dialects import mysql
24+
mysql_tinyinteger = mysql.base.MSTinyInteger
1725
from sqlalchemy.engine import default
1826
from sqlalchemy.sql import compiler
1927
from sqlalchemy.sql.compiler import SQLCompiler
2028

2129
from pyhive import presto
2230
from pyhive.common import UniversalSet
2331

32+
sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
2433

2534
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
2635
# Just quote everything to make things simpler / easier to upgrade
@@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
2938

3039
_type_map = {
3140
'boolean': types.Boolean,
32-
'tinyint': mysql.MSTinyInteger,
41+
'tinyint': mysql_tinyinteger,
3342
'smallint': types.SmallInteger,
3443
'integer': types.Integer,
3544
'bigint': types.BigInteger,
@@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect):
8089
supports_multivalues_insert = True
8190
supports_unicode_statements = True
8291
supports_unicode_binds = True
92+
supports_statement_cache = False
8393
returns_unicode_strings = True
8494
description_encoding = None
8595
supports_native_boolean = True
@@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect):
8898
@classmethod
8999
def dbapi(cls):
90100
return presto
101+
102+
@classmethod
103+
def import_dbapi(cls):
104+
return presto
91105

92106
def create_connect_args(self, url):
93107
db_parts = (url.database or 'hive').split('/')
@@ -108,14 +122,14 @@ def create_connect_args(self, url):
108122
return [], kwargs
109123

110124
def get_schema_names(self, connection, **kw):
111-
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
125+
return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]
112126

113127
def _get_table_columns(self, connection, table_name, schema):
114128
full_table = self.identifier_preparer.quote_identifier(table_name)
115129
if schema:
116130
full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
117131
try:
118-
return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
132+
return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
119133
except (presto.DatabaseError, exc.DatabaseError) as e:
120134
# Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
121135
# it successfully does in the Hive version. The difference with Presto is that this
@@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema):
134148
else:
135149
raise
136150

137-
def has_table(self, connection, table_name, schema=None):
151+
def has_table(self, connection, table_name, schema=None, **kw):
138152
try:
139153
self._get_table_columns(connection, table_name, schema)
140154
return True
@@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
176190
# - a boolean column named "Partition Key"
177191
# - a string in the "Comment" column
178192
# - a string in the "Extra" column
193+
if sqlalchemy_version >= 1.4:
194+
row = row._mapping
179195
is_partition_key = (
180196
(part_key in row and row[part_key])
181197
or row['Comment'].startswith(part_key)
@@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw):
192208
query = 'SHOW TABLES'
193209
if schema:
194210
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
195-
return [row.Table for row in connection.execute(query)]
211+
return [row.Table for row in connection.execute(text(query))]
196212

197213
def do_rollback(self, dbapi_connection):
198214
# No transactions for Presto

pyhive/sqlalchemy_trino.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from sqlalchemy import types
1414
from sqlalchemy import util
1515
# TODO shouldn't use mysql type
16-
from sqlalchemy.databases import mysql
16+
try:
17+
from sqlalchemy.databases import mysql
18+
mysql_tinyinteger = mysql.MSTinyInteger
19+
except ImportError:
20+
# Required for SQLAlchemy>=2.0
21+
from sqlalchemy.dialects import mysql
22+
mysql_tinyinteger = mysql.base.MSTinyInteger
1723
from sqlalchemy.engine import default
1824
from sqlalchemy.sql import compiler
1925
from sqlalchemy.sql.compiler import SQLCompiler
@@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer):
2834

2935
_type_map = {
3036
'boolean': types.Boolean,
31-
'tinyint': mysql.MSTinyInteger,
37+
'tinyint': mysql_tinyinteger,
3238
'smallint': types.SmallInteger,
3339
'integer': types.Integer,
3440
'bigint': types.BigInteger,
@@ -71,3 +77,7 @@ class TrinoDialect(PrestoDialect):
7177
@classmethod
7278
def dbapi(cls):
7379
return trino
80+
81+
@classmethod
82+
def import_dbapi(cls):
83+
return trino

0 commit comments

Comments
 (0)