Skip to content

Msgpack hooks possibility #154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tarantool/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ def __init__(self, host, port,
connect_now=True,
encoding=ENCODING_DEFAULT,
call_16=False,
connection_timeout=CONNECTION_TIMEOUT):
connection_timeout=CONNECTION_TIMEOUT,
pack_default=None,
unpack_object_hook=None,
unpack_list_hook=None,
unpack_object_pairs_hook=None,
unpack_ext_hook=None):
'''
Initialize a connection to the server.

Expand Down Expand Up @@ -126,6 +131,14 @@ def __init__(self, host, port,
self.encoding = encoding
self.call_16 = call_16
self.connection_timeout = connection_timeout
self.pack_default = pack_default
unpack_hooks = {
"object_hook": unpack_object_hook,
"list_hook": unpack_list_hook,
"object_pairs_hook": unpack_object_pairs_hook,
"ext_hook": unpack_ext_hook,
}
self.unpack_hooks = {k: v for (k, v) in unpack_hooks.items() if v is not None}
if connect_now:
self.connect()

Expand Down
33 changes: 19 additions & 14 deletions tarantool/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, conn):
self.conn = conn
self._sync = None
self._body = ''
self.pack_default = getattr(conn, "pack_default", None)

def __bytes__(self):
return self.header(len(self._body)) + self._body
Expand All @@ -88,6 +89,10 @@ def header(self, length):

return msgpack.dumps(length + len(header)) + header

def msgpack_dumps(self, obj):
return msgpack.dumps(obj, default=self.pack_default)



class RequestInsert(Request):
'''
Expand All @@ -102,7 +107,7 @@ def __init__(self, conn, space_no, values):
super(RequestInsert, self).__init__(conn)
assert isinstance(values, (tuple, list))

request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
IPROTO_TUPLE: values})

self._body = request_body
Expand Down Expand Up @@ -131,19 +136,19 @@ def sha1(values):
hash2 = sha1((hash1,))
scramble = sha1((salt, hash2))
scramble = strxor(hash1, scramble)
request_body = msgpack.dumps({IPROTO_USER_NAME: user,
request_body = self.msgpack_dumps({IPROTO_USER_NAME: user,
IPROTO_TUPLE: ("chap-sha1", scramble)})
self._body = request_body

def header(self, length):
self._sync = self.conn.generate_sync()
# Set IPROTO_SCHEMA_ID: 0 to avoid SchemaReloadException
# It is ok to use 0 in auth every time.
header = msgpack.dumps({IPROTO_CODE: self.request_type,
header = self.msgpack_dumps({IPROTO_CODE: self.request_type,
IPROTO_SYNC: self._sync,
IPROTO_SCHEMA_ID: 0})

return msgpack.dumps(length + len(header)) + header
return self.msgpack_dumps(length + len(header)) + header


class RequestReplace(Request):
Expand All @@ -159,7 +164,7 @@ def __init__(self, conn, space_no, values):
super(RequestReplace, self).__init__(conn)
assert isinstance(values, (tuple, list))

request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
IPROTO_TUPLE: values})

self._body = request_body
Expand All @@ -177,7 +182,7 @@ def __init__(self, conn, space_no, index_no, key):
'''
super(RequestDelete, self).__init__(conn)

request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
IPROTO_INDEX_ID: index_no,
IPROTO_KEY: key})

Expand All @@ -193,7 +198,7 @@ class RequestSelect(Request):
# pylint: disable=W0231
def __init__(self, conn, space_no, index_no, key, offset, limit, iterator):
super(RequestSelect, self).__init__(conn)
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
IPROTO_INDEX_ID: index_no,
IPROTO_OFFSET: offset,
IPROTO_LIMIT: limit,
Expand All @@ -214,7 +219,7 @@ class RequestUpdate(Request):
def __init__(self, conn, space_no, index_no, key, op_list):
super(RequestUpdate, self).__init__(conn)

request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
IPROTO_INDEX_ID: index_no,
IPROTO_KEY: key,
IPROTO_TUPLE: op_list})
Expand All @@ -235,7 +240,7 @@ def __init__(self, conn, name, args, call_16):
super(RequestCall, self).__init__(conn)
assert isinstance(args, (list, tuple))

request_body = msgpack.dumps({IPROTO_FUNCTION_NAME: name,
request_body = self.msgpack_dumps({IPROTO_FUNCTION_NAME: name,
IPROTO_TUPLE: args})

self._body = request_body
Expand All @@ -252,7 +257,7 @@ def __init__(self, conn, name, args):
super(RequestEval, self).__init__(conn)
assert isinstance(args, (list, tuple))

request_body = msgpack.dumps({IPROTO_EXPR: name,
request_body = self.msgpack_dumps({IPROTO_EXPR: name,
IPROTO_TUPLE: args})

self._body = request_body
Expand Down Expand Up @@ -280,7 +285,7 @@ class RequestUpsert(Request):
def __init__(self, conn, space_no, index_no, tuple_value, op_list):
super(RequestUpsert, self).__init__(conn)

request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
IPROTO_INDEX_ID: index_no,
IPROTO_TUPLE: tuple_value,
IPROTO_OPS: op_list})
Expand All @@ -297,7 +302,7 @@ class RequestJoin(Request):
# pylint: disable=W0231
def __init__(self, conn, server_uuid):
super(RequestJoin, self).__init__(conn)
request_body = msgpack.dumps({IPROTO_SERVER_UUID: server_uuid})
request_body = self.msgpack_dumps({IPROTO_SERVER_UUID: server_uuid})
self._body = request_body


Expand All @@ -312,7 +317,7 @@ def __init__(self, conn, cluster_uuid, server_uuid, vclock):
super(RequestSubscribe, self).__init__(conn)
assert isinstance(vclock, dict)

request_body = msgpack.dumps({
request_body = self.msgpack_dumps({
IPROTO_CLUSTER_UUID: cluster_uuid,
IPROTO_SERVER_UUID: server_uuid,
IPROTO_VCLOCK: vclock
Expand All @@ -329,6 +334,6 @@ class RequestOK(Request):
# pylint: disable=W0231
def __init__(self, conn, sync):
super(RequestOK, self).__init__(conn)
request_body = msgpack.dumps({IPROTO_CODE: self.request_type,
request_body = self.msgpack_dumps({IPROTO_CODE: self.request_type,
IPROTO_SYNC: sync})
self._body = request_body
8 changes: 5 additions & 3 deletions tarantool/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(self, conn, response):
:type body: array of bytes
'''

unpack_kwargs = getattr(conn, "unpack_hooks", {})

# This is not necessary, because underlying list data structures are
# created in the __new__().
# super(Response, self).__init__()
Expand All @@ -54,11 +56,11 @@ def __init__(self, conn, response):
# Get rid of the following warning.
# > PendingDeprecationWarning: encoding is deprecated,
# > Use raw=False instead.
unpacker = msgpack.Unpacker(use_list=True, raw=False)
unpacker = msgpack.Unpacker(use_list=True, raw=False, **unpack_kwargs)
elif conn.encoding is not None:
unpacker = msgpack.Unpacker(use_list=True, encoding=conn.encoding)
unpacker = msgpack.Unpacker(use_list=True, encoding=conn.encoding, **unpack_kwargs)
else:
unpacker = msgpack.Unpacker(use_list=True)
unpacker = msgpack.Unpacker(use_list=True, **unpack_kwargs)

unpacker.feed(response)
header = unpacker.unpack()
Expand Down
4 changes: 3 additions & 1 deletion unit/suites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from .test_protocol import TestSuite_Protocol
from .test_reconnect import TestSuite_Reconnect
from .test_mesh import TestSuite_Mesh
from .test_hooks import TestSuite_DefaultAndObjectHook

test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol,
TestSuite_Reconnect, TestSuite_Mesh)
TestSuite_Reconnect, TestSuite_Mesh,
TestSuite_DefaultAndObjectHook,)

def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
Expand Down
92 changes: 92 additions & 0 deletions unit/suites/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-

import sys
import unittest
from datetime import datetime

import tarantool
from .lib.tarantool_server import TarantoolServer


def object_unpack(obj):
if obj.get("__type__") == "datetime":
return datetime.fromtimestamp(obj['obj'])
return obj


def object_pack(obj):
if isinstance(obj, datetime):
return {"__type__": "datetime", "obj": obj.timestamp()}
return obj


def list_unpack(lst):
if lst and lst[0] == "_my_datetime":
return datetime.fromtimestamp(lst[1])
return lst


def list_pack(obj):
if isinstance(obj, datetime):
return ["_my_datetime", obj.timestamp()]
return obj


class TestSuite_DefaultAndObjectHook(unittest.TestCase):
@classmethod
def setUpClass(self):
print(' PACK/UNPACK HOOKs '.center(70, '='), file=sys.stderr)
print('-' * 70, file=sys.stderr)
self.srv = TarantoolServer()
self.srv.script = 'unit/suites/box.lua'
self.srv.start()
self.srv.admin.execute("simple_return = function(a) return a end")
self.srv.admin.execute(
"box.schema.user.grant('guest','execute','universe')")

def test_00_not_set(self):
con = tarantool.Connection(self.srv.host, self.srv.args['primary'])

ret = con.call("simple_return",
{"__type__": "datetime", "obj": 1546300800})
self.assertDictEqual(ret._data[0],
{"__type__": "datetime", "obj": 1546300800})

with self.assertRaises(TypeError):
con.call("simple_return", datetime.fromtimestamp(1546300800))

def test_01_set_default(self):
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
pack_default=object_pack)

ret = con.call("simple_return", datetime.fromtimestamp(1546300800))
self.assertDictEqual(ret._data[0],
{"__type__": "datetime", "obj": 1546300800})

def test_02_set_object_hook(self):
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
unpack_object_hook=object_unpack)

ret = con.call("simple_return",
{"__type__": "datetime", "obj": 1546300800})
self.assertEqual(ret._data[0], datetime.fromtimestamp(1546300800))

ret = con.call("simple_return",
{"__type__": "1datetime", "obj": 1546300800})
self.assertEqual(ret._data[0],
{"__type__": "1datetime", "obj": 1546300800})

def test_03_set_object_hook_and_default(self):
dt = datetime.fromtimestamp(1546300800)
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
unpack_object_hook=object_unpack, pack_default=object_pack)
ret = con.call("simple_return", dt)
self.assertEqual(ret._data[0], dt)


def test_04_set_list_hook_and_default(self):
dt = datetime.fromtimestamp(1546300800)
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
unpack_list_hook=list_unpack, pack_default=list_pack)
ret = con.call("simple_return", dt)
self.assertEqual(ret._data[0], dt)