Skip to content

Support presto Decimal, Date and Timestamp types #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
from pyhive.exc import * # noqa
import re
import base64
import getpass
import logging
Expand All @@ -35,8 +36,8 @@


def connect(*args, **kwargs):
"""Constructor for creating a connection to the database. See class :py:class:`Connection` for
arguments.
"""Constructor for creating a connection to the database.
See class :py:class:`Connection` for arguments.

:returns: a :py:class:`Connection` object.
"""
Expand Down Expand Up @@ -163,11 +164,24 @@ def description(self):
)
if self._columns is None:
return None
return [
# name, type_code, display_size, internal_size, precision, scale, null_ok
(col['name'], col['type'], None, None, None, None, True)
for col in self._columns
]

def _col_process(col_name, col_type):
    """Build the DB-API description tuple for one result column.

    Decimal column types carry their parameters inline, e.g. ``decimal(10,1)``.
    The parameters are stripped from the type code so that ``decimal`` can be
    mapped through sqlalchemy to a Decimal type, and are reported in the
    ``precision`` and ``scale`` slots of the tuple instead.

    :param col_name: string - column name
    :param col_type: string - column type
    :return: name, type_code, display_size, internal_size, precision, scale, null_ok
    """
    precision = scale = None
    if col_type.startswith('decimal'):
        # A bare ``decimal`` (or any unexpected parameter formatting) yields no
        # match; leave precision/scale unknown instead of raising AttributeError.
        m = re.search(r'(\d+)\s*,\s*(\d+)', col_type)
        if m is not None:
            precision = int(m.group(1))
            scale = int(m.group(2))
        col_type = 'decimal'
    return col_name, col_type, None, None, precision, scale, True

return [_col_process(col['name'], col['type']) for col in self._columns]

def execute(self, operation, parameters=None):
"""Prepare and execute a database operation (query or command).
Expand Down
67 changes: 62 additions & 5 deletions pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from __future__ import unicode_literals

import re
import datetime
import decimal
from sqlalchemy import exc
from sqlalchemy import processors
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
Expand All @@ -27,6 +30,50 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = UniversalSet()


class PrestoStringTypeBase(types.TypeDecorator):
    """Translates strings returned by Presto into something else.

    Subclasses supply ``process_result_value`` to do the conversion; binding
    parameters of these types is not supported.
    """
    impl = types.String

    def process_bind_param(self, value, dialect):
        """Reject bound parameters.

        :raises NotImplementedError: always.
        """
        # The message previously named Hive (copied from the Hive dialect).
        raise NotImplementedError("Writing to Presto not supported")


class PrestoDate(PrestoStringTypeBase):
    """Date column type whose result strings are turned into date objects."""
    impl = types.DATE

    def process_result_value(self, value, dialect):
        """Convert ``value`` using sqlalchemy's ``str_to_date`` processor."""
        to_date = processors.str_to_date
        return to_date(value)


class PrestoTimestamp(PrestoStringTypeBase):
    """Translates timestamp strings to datetime objects"""
    impl = types.TIMESTAMP
    DATETIME_RE = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")

    def process_result_value(self, value, dialect):
        """Parse a ``YYYY-MM-DD HH:MM:SS[.fff]`` string into a ``datetime.datetime``.

        :param value: timestamp string from the result set, or None
        :param dialect: the dialect in use (unused)
        :return: ``datetime.datetime``, or None when ``value`` is None
        :raises ValueError: if ``value`` does not start with a recognizable timestamp
        """
        if value is None:
            return None

        m = self.DATETIME_RE.match(value)
        if m is None:
            raise ValueError("Couldn't parse timestamp string: %r" % (value,))

        # Use a *string* default for the optional fractional group so that the
        # padding below also works for values without fractional seconds.
        grps = list(m.groups('0'))
        # The fraction is a decimal fraction of a second: right-pad it to six
        # digits so '.123' becomes 123000 microseconds (i.e. the milliseconds
        # are preserved), and drop digits beyond microsecond resolution, which
        # datetime cannot represent.
        grps[6] = grps[6][:6].ljust(6, '0')
        return datetime.datetime(*map(int, grps))


class PrestoDecimal(PrestoStringTypeBase):
    """Decimal column type whose result strings become ``decimal.Decimal``."""

    def process_result_value(self, value, dialect):
        """Return ``value`` as a ``decimal.Decimal``; ``None`` is passed through."""
        return None if value is None else decimal.Decimal(value)


_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
Expand All @@ -36,9 +83,10 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
'real': types.Float,
'double': types.Float,
'varchar': types.String,
'timestamp': types.TIMESTAMP,
'date': types.DATE,
'timestamp': PrestoTimestamp,
'date': PrestoDate,
'varbinary': types.VARBINARY,
'decimal': PrestoDecimal
}


Expand Down Expand Up @@ -81,6 +129,11 @@ class PrestoDialect(default.DefaultDialect):
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
dbapi_type_map = {
'date': PrestoDate(),
'timestamp': PrestoTimestamp(),
'decimal': PrestoDecimal()
}
type_compiler = PrestoTypeCompiler

@classmethod
Expand Down Expand Up @@ -144,7 +197,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
result = []
for row in rows:
try:
coltype = _type_map[row.Type]
# Take out the more detailed type information
# e.g. 'map<int,int>' -> 'map'
# 'decimal(10,1)' -> 'decimal'
col_type = re.search(r'^\w+', row.Type).group(0)
coltype = _type_map[col_type]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
coltype = types.NullType
Expand All @@ -158,11 +215,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
return result

def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Hive has no support for foreign keys.
# Presto has no support for foreign keys.
return []

def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Hive has no support for primary keys.
# Presto has no support for primary keys.
return []

def get_indexes(self, connection, table_name, schema=None, **kw):
Expand Down
2 changes: 2 additions & 0 deletions pyhive/tests/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_complex(self, cursor):
('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True),
('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True),
('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True),
('one_row_complex.date', 'DATE_TYPE', None, None, None, None, True),
('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True),
('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True),
('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True),
Expand All @@ -73,6 +74,7 @@ def test_complex(self, cursor):
0.25,
'a string',
'1970-01-01 00:00:00.0',
'1970-02-01',
b'123',
'[1,2]',
'{1:2,3:4}',
Expand Down
4 changes: 3 additions & 1 deletion pyhive/tests/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ def test_complex(self, cursor):
('double', 'double', None, None, None, None, True),
('string', 'varchar', None, None, None, None, True),
('timestamp', 'timestamp', None, None, None, None, True),
('date', 'date', None, None, None, None, True),
('binary', 'varbinary', None, None, None, None, True),
('array', 'array(integer)', None, None, None, None, True),
('map', 'map(integer,integer)', None, None, None, None, True),
('struct', 'row(a integer,b integer)', None, None, None, None, True),
# ('union', 'varchar', None, None, None, None, True),
('decimal', 'decimal(10,1)', None, None, None, None, True),
('decimal', 'decimal', None, None, 10, 1, True),
])
rows = cursor.fetchall()
expected = [(
Expand All @@ -78,6 +79,7 @@ def test_complex(self, cursor):
0.25,
'a string',
'1970-01-01 00:00:00.000',
'1970-02-01',
b'123',
[1, 2],
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
Expand Down
4 changes: 3 additions & 1 deletion pyhive/tests/test_sqlalchemy_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
0.25,
'a string',
datetime.datetime(1970, 1, 1),
datetime.date(1970, 2, 1),
b'123',
'[1,2]',
'{1:2,3:4}',
Expand Down Expand Up @@ -68,7 +69,7 @@ def test_dotted_column_names_raw(self, engine, connection):
def test_reflect_select(self, engine, connection):
"""reflecttable should be able to fill in a table from the name"""
one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
self.assertEqual(len(one_row_complex.c), 15)
self.assertEqual(len(one_row_complex.c), 16)
self.assertIsInstance(one_row_complex.c.string, Column)
rows = one_row_complex.select().execute().fetchall()
self.assertEqual(len(rows), 1)
Expand All @@ -84,6 +85,7 @@ def test_reflect_select(self, engine, connection):
self.assertIsInstance(one_row_complex.c.double.type, types.Float)
self.assertIsInstance(one_row_complex.c.string.type, types.String)
self.assertIsInstance(one_row_complex.c.timestamp.type, HiveTimestamp)
self.assertIsInstance(one_row_complex.c.date.type, HiveDate)
self.assertIsInstance(one_row_complex.c.binary.type, types.String)
self.assertIsInstance(one_row_complex.c.array.type, types.String)
self.assertIsInstance(one_row_complex.c.map.type, types.String)
Expand Down
15 changes: 10 additions & 5 deletions pyhive/tests/test_sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from builtins import str
from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
from pyhive.tests.sqlalchemy_test_case import with_engine_connection
from pyhive.sqlalchemy_presto import PrestoTimestamp, PrestoDate, PrestoDecimal
from sqlalchemy import types
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import Column
Expand All @@ -12,6 +13,8 @@

import contextlib
import unittest
import datetime
import decimal


class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase):
Expand All @@ -29,7 +32,7 @@ def test_reflect_select(self, engine, connection):
"""reflecttable should be able to fill in a table from the name"""
one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
# Presto ignores the union column
self.assertEqual(len(one_row_complex.c), 15 - 1)
self.assertEqual(len(one_row_complex.c), 16 - 1)
self.assertIsInstance(one_row_complex.c.string, Column)
rows = one_row_complex.select().execute().fetchall()
self.assertEqual(len(rows), 1)
Expand All @@ -42,13 +45,14 @@ def test_reflect_select(self, engine, connection):
0.5,
0.25,
'a string',
'1970-01-01 00:00:00.000',
datetime.datetime(1970, 1, 1),
datetime.date(1970, 2, 1),
b'123',
[1, 2],
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
# '{0:1}',
'0.1',
decimal.Decimal('0.1'),
])

# TODO some of these types could be filled in better
Expand All @@ -60,12 +64,13 @@ def test_reflect_select(self, engine, connection):
self.assertIsInstance(one_row_complex.c.float.type, types.Float)
self.assertIsInstance(one_row_complex.c.double.type, types.Float)
self.assertIsInstance(one_row_complex.c.string.type, String)
self.assertIsInstance(one_row_complex.c.timestamp.type, types.TIMESTAMP)
self.assertIsInstance(one_row_complex.c.timestamp.type, PrestoTimestamp)
self.assertIsInstance(one_row_complex.c.date.type, PrestoDate)
self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY)
self.assertIsInstance(one_row_complex.c.array.type, types.NullType)
self.assertIsInstance(one_row_complex.c.map.type, types.NullType)
self.assertIsInstance(one_row_complex.c.struct.type, types.NullType)
self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType)
self.assertIsInstance(one_row_complex.c.decimal.type, PrestoDecimal)

def test_url_default(self):
engine = create_engine('presto://localhost:8080/hive')
Expand Down
3 changes: 3 additions & 0 deletions scripts/make_one_row_complex.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ COLUMNS='
`double` DOUBLE,
`string` STRING,
`timestamp` TIMESTAMP,
`date` DATE,
`binary` BINARY,
`array` ARRAY<int>,
`map` MAP<int, int>,
Expand All @@ -34,6 +35,7 @@ INSERT OVERWRITE TABLE one_row_complex SELECT
0.25,
'a string',
0,
cast('1970-02-01' AS date),
'123',
array(1, 2),
map(1, 2, 3, 4),
Expand All @@ -52,6 +54,7 @@ INSERT OVERWRITE TABLE one_row_complex_null SELECT
null,
null,
null,
null,
IF(false, array(1, 2), null),
IF(false, map(1, 2, 3, 4), null),
IF(false, named_struct('a', 1, 'b', 2), null),
Expand Down