diff --git a/tarantool/connection.py b/tarantool/connection.py index 6f330896..358aadd7 100644 --- a/tarantool/connection.py +++ b/tarantool/connection.py @@ -91,7 +91,12 @@ def __init__(self, host, port, connect_now=True, encoding=ENCODING_DEFAULT, call_16=False, - connection_timeout=CONNECTION_TIMEOUT): + connection_timeout=CONNECTION_TIMEOUT, + pack_default=None, + unpack_object_hook=None, + unpack_list_hook=None, + unpack_object_pairs_hook=None, + unpack_ext_hook=None): ''' Initialize a connection to the server. @@ -126,6 +131,14 @@ def __init__(self, host, port, self.encoding = encoding self.call_16 = call_16 self.connection_timeout = connection_timeout + self.pack_default = pack_default + unpack_hooks = { + "object_hook": unpack_object_hook, + "list_hook": unpack_list_hook, + "object_pairs_hook": unpack_object_pairs_hook, + "ext_hook": unpack_ext_hook, + } + self.unpack_hooks = {k: v for (k, v) in unpack_hooks.items() if v is not None} if connect_now: self.connect() diff --git a/tarantool/request.py b/tarantool/request.py index 2e45c8d9..6153b9e2 100644 --- a/tarantool/request.py +++ b/tarantool/request.py @@ -64,6 +64,7 @@ def __init__(self, conn): self.conn = conn self._sync = None self._body = '' + self.pack_default = getattr(conn, "pack_default", None) def __bytes__(self): return self.header(len(self._body)) + self._body @@ -88,6 +89,10 @@ def header(self, length): return msgpack.dumps(length + len(header)) + header + def msgpack_dumps(self, obj): + return msgpack.dumps(obj, default=self.pack_default) + + class RequestInsert(Request): ''' @@ -102,7 +107,7 @@ def __init__(self, conn, space_no, values): super(RequestInsert, self).__init__(conn) assert isinstance(values, (tuple, list)) - request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no, + request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no, IPROTO_TUPLE: values}) self._body = request_body @@ -131,7 +136,7 @@ def sha1(values): hash2 = sha1((hash1,)) scramble = sha1((salt, hash2)) scramble = strxor(hash1, scramble) - request_body = msgpack.dumps({IPROTO_USER_NAME: user, + request_body = self.msgpack_dumps({IPROTO_USER_NAME: user, IPROTO_TUPLE: ("chap-sha1", scramble)}) self._body = request_body @@ -139,11 +144,11 @@ def header(self, length): self._sync = self.conn.generate_sync() # Set IPROTO_SCHEMA_ID: 0 to avoid SchemaReloadException # It is ok to use 0 in auth every time. - header = msgpack.dumps({IPROTO_CODE: self.request_type, + header = self.msgpack_dumps({IPROTO_CODE: self.request_type, IPROTO_SYNC: self._sync, IPROTO_SCHEMA_ID: 0}) - return msgpack.dumps(length + len(header)) + header + return self.msgpack_dumps(length + len(header)) + header class RequestReplace(Request): @@ -159,7 +164,7 @@ def __init__(self, conn, space_no, values): super(RequestReplace, self).__init__(conn) assert isinstance(values, (tuple, list)) - request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no, + request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no, IPROTO_TUPLE: values}) self._body = request_body @@ -177,7 +182,7 @@ def __init__(self, conn, space_no, index_no, key): ''' super(RequestDelete, self).__init__(conn) - request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no, + request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no, IPROTO_INDEX_ID: index_no, IPROTO_KEY: key}) @@ -193,7 +198,7 @@ class RequestSelect(Request): # pylint: disable=W0231 def __init__(self, conn, space_no, index_no, key, offset, limit, iterator): super(RequestSelect, self).__init__(conn) - request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no, + request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no, IPROTO_INDEX_ID: index_no, IPROTO_OFFSET: offset, IPROTO_LIMIT: limit, @@ -214,7 +219,7 @@ class RequestUpdate(Request): def __init__(self, conn, space_no, index_no, key, op_list): super(RequestUpdate, self).__init__(conn) - request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no, + request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no, IPROTO_INDEX_ID: index_no, IPROTO_KEY: key, IPROTO_TUPLE: op_list}) @@ -235,7 +240,7 @@ def __init__(self, conn, name, args, call_16): super(RequestCall, self).__init__(conn) assert isinstance(args, (list, tuple)) - request_body = msgpack.dumps({IPROTO_FUNCTION_NAME: name, + request_body = self.msgpack_dumps({IPROTO_FUNCTION_NAME: name, IPROTO_TUPLE: args}) self._body = request_body @@ -252,7 +257,7 @@ def __init__(self, conn, name, args): super(RequestEval, self).__init__(conn) assert isinstance(args, (list, tuple)) - request_body = msgpack.dumps({IPROTO_EXPR: name, + request_body = self.msgpack_dumps({IPROTO_EXPR: name, IPROTO_TUPLE: args}) self._body = request_body @@ -280,7 +285,7 @@ class RequestUpsert(Request): def __init__(self, conn, space_no, index_no, tuple_value, op_list): super(RequestUpsert, self).__init__(conn) - request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no, + request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no, IPROTO_INDEX_ID: index_no, IPROTO_TUPLE: tuple_value, IPROTO_OPS: op_list}) @@ -297,7 +302,7 @@ class RequestJoin(Request): # pylint: disable=W0231 def __init__(self, conn, server_uuid): super(RequestJoin, self).__init__(conn) - request_body = msgpack.dumps({IPROTO_SERVER_UUID: server_uuid}) + request_body = self.msgpack_dumps({IPROTO_SERVER_UUID: server_uuid}) self._body = request_body @@ -312,7 +317,7 @@ def __init__(self, conn, cluster_uuid, server_uuid, vclock): super(RequestSubscribe, self).__init__(conn) assert isinstance(vclock, dict) - request_body = msgpack.dumps({ + request_body = self.msgpack_dumps({ IPROTO_CLUSTER_UUID: cluster_uuid, IPROTO_SERVER_UUID: server_uuid, IPROTO_VCLOCK: vclock @@ -329,6 +334,6 @@ class RequestOK(Request): # pylint: disable=W0231 def __init__(self, conn, sync): super(RequestOK, self).__init__(conn) - request_body = msgpack.dumps({IPROTO_CODE: self.request_type, + request_body = self.msgpack_dumps({IPROTO_CODE: self.request_type, IPROTO_SYNC: sync}) self._body = request_body diff --git a/tarantool/response.py b/tarantool/response.py index 9516cd39..ddc52cb8 100644 --- a/tarantool/response.py +++ b/tarantool/response.py @@ -46,6 +46,8 @@ def __init__(self, conn, response): :type body: array of bytes ''' + unpack_kwargs = getattr(conn, "unpack_hooks", {}) + # This is not necessary, because underlying list data structures are # created in the __new__(). # super(Response, self).__init__() @@ -54,11 +56,11 @@ def __init__(self, conn, response): # Get rid of the following warning. # > PendingDeprecationWarning: encoding is deprecated, # > Use raw=False instead. - unpacker = msgpack.Unpacker(use_list=True, raw=False) + unpacker = msgpack.Unpacker(use_list=True, raw=False, **unpack_kwargs) elif conn.encoding is not None: - unpacker = msgpack.Unpacker(use_list=True, encoding=conn.encoding) + unpacker = msgpack.Unpacker(use_list=True, encoding=conn.encoding, **unpack_kwargs) else: - unpacker = msgpack.Unpacker(use_list=True) + unpacker = msgpack.Unpacker(use_list=True, **unpack_kwargs) unpacker.feed(response) header = unpacker.unpack() diff --git a/unit/suites/__init__.py b/unit/suites/__init__.py index ead75297..80edba21 100644 --- a/unit/suites/__init__.py +++ b/unit/suites/__init__.py @@ -9,9 +9,11 @@ from .test_protocol import TestSuite_Protocol from .test_reconnect import TestSuite_Reconnect from .test_mesh import TestSuite_Mesh +from .test_hooks import TestSuite_DefaultAndObjectHook test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol, - TestSuite_Reconnect, TestSuite_Mesh) + TestSuite_Reconnect, TestSuite_Mesh, + TestSuite_DefaultAndObjectHook,) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/unit/suites/test_hooks.py b/unit/suites/test_hooks.py new file mode 100644 index 00000000..51f17078 --- /dev/null +++ b/unit/suites/test_hooks.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +import sys +import unittest +from datetime import datetime + +import tarantool +from .lib.tarantool_server import TarantoolServer + + +def object_unpack(obj): + if obj.get("__type__") == "datetime": + return datetime.fromtimestamp(obj['obj']) + return obj + + +def object_pack(obj): + if isinstance(obj, datetime): + return {"__type__": "datetime", "obj": obj.timestamp()} + return obj + + +def list_unpack(lst): + if lst and lst[0] == "_my_datetime": + return datetime.fromtimestamp(lst[1]) + return lst + + +def list_pack(obj): + if isinstance(obj, datetime): + return ["_my_datetime", obj.timestamp()] + return obj + + +class TestSuite_DefaultAndObjectHook(unittest.TestCase): + @classmethod + def setUpClass(self): + print(' PACK/UNPACK HOOKs '.center(70, '='), file=sys.stderr) + print('-' * 70, file=sys.stderr) + self.srv = TarantoolServer() + self.srv.script = 'unit/suites/box.lua' + self.srv.start() + self.srv.admin.execute("simple_return = function(a) return a end") + self.srv.admin.execute( + "box.schema.user.grant('guest','execute','universe')") + + def test_00_not_set(self): + con = tarantool.Connection(self.srv.host, self.srv.args['primary']) + + ret = con.call("simple_return", + {"__type__": "datetime", "obj": 1546300800}) + self.assertDictEqual(ret._data[0], + {"__type__": "datetime", "obj": 1546300800}) + + with self.assertRaises(TypeError): + con.call("simple_return", datetime.fromtimestamp(1546300800)) + + def test_01_set_default(self): + con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + pack_default=object_pack) + + ret = con.call("simple_return", datetime.fromtimestamp(1546300800)) + self.assertDictEqual(ret._data[0], + {"__type__": "datetime", "obj": 1546300800}) + + def test_02_set_object_hook(self): + con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + unpack_object_hook=object_unpack) + + ret = con.call("simple_return", + {"__type__": "datetime", "obj": 1546300800}) + self.assertEqual(ret._data[0], datetime.fromtimestamp(1546300800)) + + ret = con.call("simple_return", + {"__type__": "1datetime", "obj": 1546300800}) + self.assertEqual(ret._data[0], + {"__type__": "1datetime", "obj": 1546300800}) + + def test_03_set_object_hook_and_default(self): + dt = datetime.fromtimestamp(1546300800) + con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + unpack_object_hook=object_unpack, pack_default=object_pack) + ret = con.call("simple_return", dt) + self.assertEqual(ret._data[0], dt) + + + def test_04_set_list_hook_and_default(self): + dt = datetime.fromtimestamp(1546300800) + con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + unpack_list_hook=list_unpack, pack_default=list_pack) + ret = con.call("simple_return", dt) + self.assertEqual(ret._data[0], dt)