Skip to content

Commit ed38e77

Browse files
authored
Add dialect to ft aggregate (#2537)
* add dialect to aggregate request * added test * format * async test
1 parent f10c81a commit ed38e77

File tree

3 files changed

+131
-81
lines changed

3 files changed

+131
-81
lines changed

redis/commands/search/aggregation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(self, query="*"):
108108
self._with_schema = False
109109
self._verbatim = False
110110
self._cursor = []
111+
self._dialect = None
111112

112113
def load(self, *fields):
113114
"""
@@ -321,10 +322,22 @@ def build_args(self):
321322
ret.append(str(len(self._loadfields)))
322323
ret.extend(self._loadfields)
323324

325+
if self._dialect:
326+
ret.extend(["DIALECT", self._dialect])
327+
324328
ret.extend(self._aggregateplan)
325329

326330
return ret
327331

332+
def dialect(self, dialect):
333+
"""
334+
Add a dialect field to the aggregate command.
335+
336+
- **dialect** - dialect version to execute the query under
337+
"""
338+
self._dialect = dialect
339+
return self
340+
328341

329342
class Cursor:
330343
def __init__(self, cid):

tests/test_asyncio/test_search.py

Lines changed: 98 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -832,100 +832,127 @@ async def test_aggregations_groupby(modclient: redis.Redis):
832832
},
833833
)
834834

835-
req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count())
835+
for dialect in [1, 2]:
836+
req = (
837+
aggregations.AggregateRequest("redis")
838+
.group_by("@parent", reducers.count())
839+
.dialect(dialect)
840+
)
836841

837-
res = (await modclient.ft().aggregate(req)).rows[0]
838-
assert res[1] == "redis"
839-
assert res[3] == "3"
842+
res = (await modclient.ft().aggregate(req)).rows[0]
843+
assert res[1] == "redis"
844+
assert res[3] == "3"
840845

841-
req = aggregations.AggregateRequest("redis").group_by(
842-
"@parent", reducers.count_distinct("@title")
843-
)
846+
req = (
847+
aggregations.AggregateRequest("redis")
848+
.group_by("@parent", reducers.count_distinct("@title"))
849+
.dialect(dialect)
850+
)
844851

845-
res = (await modclient.ft().aggregate(req)).rows[0]
846-
assert res[1] == "redis"
847-
assert res[3] == "3"
852+
res = (await modclient.ft().aggregate(req)).rows[0]
853+
assert res[1] == "redis"
854+
assert res[3] == "3"
848855

849-
req = aggregations.AggregateRequest("redis").group_by(
850-
"@parent", reducers.count_distinctish("@title")
851-
)
856+
req = (
857+
aggregations.AggregateRequest("redis")
858+
.group_by("@parent", reducers.count_distinctish("@title"))
859+
.dialect(dialect)
860+
)
852861

853-
res = (await modclient.ft().aggregate(req)).rows[0]
854-
assert res[1] == "redis"
855-
assert res[3] == "3"
862+
res = (await modclient.ft().aggregate(req)).rows[0]
863+
assert res[1] == "redis"
864+
assert res[3] == "3"
856865

857-
req = aggregations.AggregateRequest("redis").group_by(
858-
"@parent", reducers.sum("@random_num")
859-
)
866+
req = (
867+
aggregations.AggregateRequest("redis")
868+
.group_by("@parent", reducers.sum("@random_num"))
869+
.dialect(dialect)
870+
)
860871

861-
res = (await modclient.ft().aggregate(req)).rows[0]
862-
assert res[1] == "redis"
863-
assert res[3] == "21" # 10+8+3
872+
res = (await modclient.ft().aggregate(req)).rows[0]
873+
assert res[1] == "redis"
874+
assert res[3] == "21" # 10+8+3
864875

865-
req = aggregations.AggregateRequest("redis").group_by(
866-
"@parent", reducers.min("@random_num")
867-
)
876+
req = (
877+
aggregations.AggregateRequest("redis")
878+
.group_by("@parent", reducers.min("@random_num"))
879+
.dialect(dialect)
880+
)
868881

869-
res = (await modclient.ft().aggregate(req)).rows[0]
870-
assert res[1] == "redis"
871-
assert res[3] == "3" # min(10,8,3)
882+
res = (await modclient.ft().aggregate(req)).rows[0]
883+
assert res[1] == "redis"
884+
assert res[3] == "3" # min(10,8,3)
872885

873-
req = aggregations.AggregateRequest("redis").group_by(
874-
"@parent", reducers.max("@random_num")
875-
)
886+
req = (
887+
aggregations.AggregateRequest("redis")
888+
.group_by("@parent", reducers.max("@random_num"))
889+
.dialect(dialect)
890+
)
876891

877-
res = (await modclient.ft().aggregate(req)).rows[0]
878-
assert res[1] == "redis"
879-
assert res[3] == "10" # max(10,8,3)
892+
res = (await modclient.ft().aggregate(req)).rows[0]
893+
assert res[1] == "redis"
894+
assert res[3] == "10" # max(10,8,3)
880895

881-
req = aggregations.AggregateRequest("redis").group_by(
882-
"@parent", reducers.avg("@random_num")
883-
)
896+
req = (
897+
aggregations.AggregateRequest("redis")
898+
.group_by("@parent", reducers.avg("@random_num"))
899+
.dialect(dialect)
900+
)
884901

885-
res = (await modclient.ft().aggregate(req)).rows[0]
886-
assert res[1] == "redis"
887-
assert res[3] == "7" # (10+3+8)/3
902+
res = (await modclient.ft().aggregate(req)).rows[0]
903+
assert res[1] == "redis"
904+
assert res[3] == "7" # (10+3+8)/3
888905

889-
req = aggregations.AggregateRequest("redis").group_by(
890-
"@parent", reducers.stddev("random_num")
891-
)
906+
req = (
907+
aggregations.AggregateRequest("redis")
908+
.group_by("@parent", reducers.stddev("random_num"))
909+
.dialect(dialect)
910+
)
892911

893-
res = (await modclient.ft().aggregate(req)).rows[0]
894-
assert res[1] == "redis"
895-
assert res[3] == "3.60555127546"
912+
res = (await modclient.ft().aggregate(req)).rows[0]
913+
assert res[1] == "redis"
914+
assert res[3] == "3.60555127546"
896915

897-
req = aggregations.AggregateRequest("redis").group_by(
898-
"@parent", reducers.quantile("@random_num", 0.5)
899-
)
916+
req = (
917+
aggregations.AggregateRequest("redis")
918+
.group_by("@parent", reducers.quantile("@random_num", 0.5))
919+
.dialect(dialect)
920+
)
900921

901-
res = (await modclient.ft().aggregate(req)).rows[0]
902-
assert res[1] == "redis"
903-
assert res[3] == "8" # median of 3,8,10
922+
res = (await modclient.ft().aggregate(req)).rows[0]
923+
assert res[1] == "redis"
924+
assert res[3] == "8" # median of 3,8,10
904925

905-
req = aggregations.AggregateRequest("redis").group_by(
906-
"@parent", reducers.tolist("@title")
907-
)
926+
req = (
927+
aggregations.AggregateRequest("redis")
928+
.group_by("@parent", reducers.tolist("@title"))
929+
.dialect(dialect)
930+
)
908931

909-
res = (await modclient.ft().aggregate(req)).rows[0]
910-
assert res[1] == "redis"
911-
assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"}
932+
res = (await modclient.ft().aggregate(req)).rows[0]
933+
assert res[1] == "redis"
934+
assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"}
912935

913-
req = aggregations.AggregateRequest("redis").group_by(
914-
"@parent", reducers.first_value("@title").alias("first")
915-
)
936+
req = (
937+
aggregations.AggregateRequest("redis")
938+
.group_by("@parent", reducers.first_value("@title").alias("first"))
939+
.dialect(dialect)
940+
)
916941

917-
res = (await modclient.ft().aggregate(req)).rows[0]
918-
assert res == ["parent", "redis", "first", "RediSearch"]
942+
res = (await modclient.ft().aggregate(req)).rows[0]
943+
assert res == ["parent", "redis", "first", "RediSearch"]
919944

920-
req = aggregations.AggregateRequest("redis").group_by(
921-
"@parent", reducers.random_sample("@title", 2).alias("random")
922-
)
945+
req = (
946+
aggregations.AggregateRequest("redis")
947+
.group_by("@parent", reducers.random_sample("@title", 2).alias("random"))
948+
.dialect(dialect)
949+
)
923950

924-
res = (await modclient.ft().aggregate(req)).rows[0]
925-
assert res[1] == "redis"
926-
assert res[2] == "random"
927-
assert len(res[3]) == 2
928-
assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"]
951+
res = (await modclient.ft().aggregate(req)).rows[0]
952+
assert res[1] == "redis"
953+
assert res[2] == "random"
954+
assert len(res[3]) == 2
955+
assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"]
929956

930957

931958
@pytest.mark.redismod

tests/test_search.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -973,16 +973,26 @@ def test_aggregations_filter(client):
973973
client.ft().client.hset("doc1", mapping={"name": "bar", "age": "25"})
974974
client.ft().client.hset("doc2", mapping={"name": "foo", "age": "19"})
975975

976-
req = aggregations.AggregateRequest("*").filter("@name=='foo' && @age < 20")
977-
res = client.ft().aggregate(req)
978-
assert len(res.rows) == 1
979-
assert res.rows[0] == ["name", "foo", "age", "19"]
980-
981-
req = aggregations.AggregateRequest("*").filter("@age > 15").sort_by("@age")
982-
res = client.ft().aggregate(req)
983-
assert len(res.rows) == 2
984-
assert res.rows[0] == ["age", "19"]
985-
assert res.rows[1] == ["age", "25"]
976+
for dialect in [1, 2]:
977+
req = (
978+
aggregations.AggregateRequest("*")
979+
.filter("@name=='foo' && @age < 20")
980+
.dialect(dialect)
981+
)
982+
res = client.ft().aggregate(req)
983+
assert len(res.rows) == 1
984+
assert res.rows[0] == ["name", "foo", "age", "19"]
985+
986+
req = (
987+
aggregations.AggregateRequest("*")
988+
.filter("@age > 15")
989+
.sort_by("@age")
990+
.dialect(dialect)
991+
)
992+
res = client.ft().aggregate(req)
993+
assert len(res.rows) == 2
994+
assert res.rows[0] == ["age", "19"]
995+
assert res.rows[1] == ["age", "25"]
986996

987997

988998
@pytest.mark.redismod

0 commit comments

Comments
 (0)