Skip to content

Commit 06872c8

Browse files
committed
feat: application flow
1 parent 6bbb181 commit 06872c8

File tree

122 files changed

+8221
-58
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+8221
-58
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: I_base_chat_pipeline.py
6+
@date:2024/1/9 17:25
7+
@desc:
8+
"""
9+
import time
10+
from abc import abstractmethod
11+
from typing import Type
12+
13+
from rest_framework import serializers
14+
15+
from dataset.models import Paragraph
16+
17+
18+
class ParagraphPipelineModel:
    """Paragraph record handed between chat-pipeline steps.

    Holds the paragraph's own fields together with its scores
    (``comprehensive_score``, ``similarity``), the names of the dataset and
    document it belongs to, and the hit-handling settings.
    Instances are normally assembled through the nested ``builder``.
    """

    def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
                 hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
        self.id = _id
        self.document_id = document_id
        self.dataset_id = dataset_id
        self.content = content
        self.title = title
        # Stored as the plain value; a stray trailing comma previously
        # wrapped it in a 1-tuple, which then leaked into to_dict().
        self.status = status
        self.is_active = is_active
        self.comprehensive_score = comprehensive_score
        self.similarity = similarity
        self.dataset_name = dataset_name
        self.document_name = document_name
        self.hit_handling_method = hit_handling_method
        self.directly_return_similarity = directly_return_similarity
        self.meta = meta

    def to_dict(self):
        """Return the paragraph as a plain dict.

        ``hit_handling_method`` and ``directly_return_similarity`` are not
        part of the returned dict.
        """
        return {
            'id': self.id,
            'document_id': self.document_id,
            'dataset_id': self.dataset_id,
            'content': self.content,
            'title': self.title,
            'status': self.status,
            'is_active': self.is_active,
            'comprehensive_score': self.comprehensive_score,
            'similarity': self.similarity,
            'dataset_name': self.dataset_name,
            'document_name': self.document_name,
            'meta': self.meta,
        }

    class builder:
        """Fluent builder: every ``add_*`` method returns the builder itself."""

        def __init__(self):
            self.similarity = None
            self.paragraph = {}
            self.comprehensive_score = None
            self.document_name = None
            self.dataset_name = None
            self.hit_handling_method = None
            self.directly_return_similarity = 0.9
            self.meta = {}

        def add_paragraph(self, paragraph):
            """Accept either a ``Paragraph`` model instance or a dict-like object.

            A ``Paragraph`` instance is flattened to a dict of the fields that
            ``build`` reads; anything else is stored as given and must
            support ``.get``.
            """
            if isinstance(paragraph, Paragraph):
                self.paragraph = {
                    'id': paragraph.id,
                    'document_id': paragraph.document_id,
                    'dataset_id': paragraph.dataset_id,
                    'content': paragraph.content,
                    'title': paragraph.title,
                    'status': paragraph.status,
                    'is_active': paragraph.is_active,
                }
            else:
                self.paragraph = paragraph
            return self

        def add_dataset_name(self, dataset_name):
            self.dataset_name = dataset_name
            return self

        def add_document_name(self, document_name):
            self.document_name = document_name
            return self

        def add_hit_handling_method(self, hit_handling_method):
            self.hit_handling_method = hit_handling_method
            return self

        def add_directly_return_similarity(self, directly_return_similarity):
            self.directly_return_similarity = directly_return_similarity
            return self

        def add_comprehensive_score(self, comprehensive_score: float):
            self.comprehensive_score = comprehensive_score
            return self

        def add_similarity(self, similarity: float):
            self.similarity = similarity
            return self

        def add_meta(self, meta: dict):
            self.meta = meta
            return self

        def build(self):
            """Create the ``ParagraphPipelineModel`` from the collected values.

            The three id fields are passed through ``str()``.
            """
            paragraph = self.paragraph
            return ParagraphPipelineModel(str(paragraph.get('id')), str(paragraph.get('document_id')),
                                          str(paragraph.get('dataset_id')),
                                          paragraph.get('content'), paragraph.get('title'),
                                          paragraph.get('status'),
                                          paragraph.get('is_active'),
                                          self.comprehensive_score, self.similarity, self.dataset_name,
                                          self.document_name, self.hit_handling_method,
                                          self.directly_return_similarity,
                                          self.meta)
116+
117+
118+
class IBaseChatPipelineStep:
    """Base class for one step of the chat pipeline.

    ``run`` validates the step arguments with the serializer returned by
    ``get_step_serializer`` and then calls ``_run``; timing information is
    recorded in ``self.context``.
    """

    def __init__(self):
        # Context of the current step, used to store this step's information.
        self.context = dict()

    @abstractmethod
    def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this step's arguments."""

    def valid_args(self, manage):
        """Validate ``manage.context`` and keep the result as ``step_args``.

        Validation errors are raised by the serializer
        (``raise_exception=True``).
        """
        serializer_cls = self.get_step_serializer(manage)
        serializer = serializer_cls(data=manage.context)
        serializer.is_valid(raise_exception=True)
        self.context['step_args'] = serializer.data

    def run(self, manage):
        """Validate the arguments, execute the step and record its duration.

        :param manage: step manager
        :return: nothing; ``start_time`` and ``run_time`` are written to
                 ``self.context``
        """
        started_at = time.time()
        self.context['start_time'] = started_at
        # Validate the arguments before doing any work.
        self.valid_args(manage)
        self._run(manage)
        elapsed = time.time() - started_at
        self.context['run_time'] = elapsed

    def _run(self, manage):
        """Step body; does nothing in the base class."""
        return None

    def execute(self, **kwargs):
        """Does nothing in the base class."""
        return None

    def get_details(self, manage, **kwargs):
        """Run details.

        :return: step details; the base class has none and returns ``None``
        """
        return None
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2024/1/9 17:23
7+
@desc:
8+
"""
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: pipeline_manage.py
6+
@date:2024/1/9 17:40
7+
@desc:
8+
"""
9+
import time
10+
from functools import reduce
11+
from typing import List, Type, Dict
12+
13+
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
14+
from common.handle.base_to_response import BaseToResponse
15+
from common.handle.impl.response.system_to_response import SystemToResponse
16+
17+
18+
class PipelineManage:
    """Runs a list of pipeline steps in order against one shared context."""

    def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
                 base_to_response: BaseToResponse = None):
        """
        :param step_list: step classes; each one is instantiated here
        :param base_to_response: response converter; a new ``SystemToResponse``
                                 is created when omitted
        """
        # Step executors
        self.step_list = [step() for step in step_list]
        # Context shared by all steps
        self.context = {'message_tokens': 0, 'answer_tokens': 0}
        # The default is created per instance. A call expression in the
        # signature would be evaluated once at definition time and the same
        # object would then be shared by every PipelineManage.
        self.base_to_response = SystemToResponse() if base_to_response is None else base_to_response

    def run(self, context: Dict = None):
        """Merge ``context`` into the shared context, then run every step in order.

        ``start_time`` is written before the merge, so a ``start_time`` key in
        ``context`` overrides it.
        """
        self.context['start_time'] = time.time()
        if context is not None:
            self.context.update(context)
        for step in self.step_list:
            step.run(self)

    def get_details(self):
        """Collect the details of every step, keyed by each detail's ``step_type``.

        Steps whose ``get_details`` returns ``None`` are skipped; if two steps
        report the same ``step_type`` the later one wins.
        """
        details = {}
        for step in self.step_list:
            item = step.get_details(self)
            if item is not None:
                details[item.get('step_type')] = item
        return details

    def get_base_to_response(self):
        return self.base_to_response

    class builder:
        """Fluent builder for ``PipelineManage``."""

        def __init__(self):
            self.step_list: List[Type[IBaseChatPipelineStep]] = []
            self.base_to_response = SystemToResponse()

        def append_step(self, step: Type[IBaseChatPipelineStep]):
            self.step_list.append(step)
            return self

        def add_base_to_response(self, base_to_response: BaseToResponse):
            self.base_to_response = base_to_response
            return self

        def build(self):
            return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2024/1/9 18:23
7+
@desc:
8+
"""
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2024/1/9 18:23
7+
@desc:
8+
"""
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: i_chat_step.py
6+
@date:2024/1/9 18:17
7+
@desc: 对话
8+
"""
9+
from abc import abstractmethod
10+
from typing import Type, List
11+
12+
from django.utils.translation import gettext_lazy as _
13+
from langchain.chat_models.base import BaseChatModel
14+
from langchain.schema import BaseMessage
15+
from rest_framework import serializers
16+
17+
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
18+
from application.chat_pipeline.pipeline_manage import PipelineManage
19+
from application.serializers.application_serializers import NoReferencesSetting
20+
from common.field.common import InstanceField
21+
from common.util.field_message import ErrMessage
22+
23+
24+
class ModelField(serializers.Field):
    """Serializer field that accepts a ``BaseChatModel`` instance and passes it through unchanged."""

    # ``Field.fail`` takes a key into ``error_messages``, not the message
    # itself, so the message is registered under a key here.
    default_error_messages = {
        'invalid': _('Model type error'),
    }

    def to_internal_value(self, data):
        """Return ``data`` as-is; fail validation unless it is a ``BaseChatModel``."""
        if not isinstance(data, BaseChatModel):
            self.fail('invalid')
        return data

    def to_representation(self, value):
        """Return the value unchanged."""
        return value
32+
33+
34+
class MessageField(serializers.Field):
    """Serializer field that accepts a ``BaseMessage`` instance and passes it through unchanged."""

    # ``Field.fail`` takes a key into ``error_messages``, not the message
    # itself, so the message is registered under a key here.
    default_error_messages = {
        'invalid': _('Message type error'),
    }

    def to_internal_value(self, data):
        """Return ``data`` as-is; fail validation unless it is a ``BaseMessage``."""
        if not isinstance(data, BaseMessage):
            self.fail('invalid')
        return data

    def to_representation(self, value):
        """Return the value unchanged."""
        return value
42+
43+
44+
class PostResponseHandler:
    """Interface for post-response processing; subclasses implement ``handler``."""

    @abstractmethod
    def handler(
            self,
            chat_id,
            chat_record_id,
            paragraph_list: List[ParagraphPipelineModel],
            problem_text: str,
            answer_text,
            manage,
            step,
            padding_problem_text: str = None,
            client_id=None,
            **kwargs):
        """Does nothing in the base class; concrete handlers override it."""
50+
51+
52+
class IChatStep(IBaseChatPipelineStep):
    """Chat step interface.

    Declares the arguments the step needs (``InstanceSerializer``), runs
    ``execute`` with the validated arguments and stores its return value in
    ``manage.context['chat_result']``.
    """

    class InstanceSerializer(serializers.Serializer):
        # Conversation (message) list
        message_list = serializers.ListField(required=True, child=MessageField(required=True),
                                             error_messages=ErrMessage.list(_("Conversation list")))
        model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
        # Paragraph list
        paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
        # Conversation id
        chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
        # User question. This is a CharField, so it uses the char error
        # messages like the other char fields below (it previously used the
        # UUID ones).
        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("User Questions")))
        # Post-processor
        post_response_handler = InstanceField(model_type=PostResponseHandler,
                                              error_messages=ErrMessage.base(_("Post-processor")))
        # Completed (padded) question
        padding_problem_text = serializers.CharField(required=False,
                                                     error_messages=ErrMessage.base(_("Completion Question")))
        # Whether to produce the output as a stream
        stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
        client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
        client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
        # Settings for the case where no referenced segment was found
        no_references_setting = NoReferencesSetting(required=True,
                                                    error_messages=ErrMessage.base(_("No reference segment settings")))

        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))

        model_setting = serializers.DictField(required=True, allow_null=True,
                                              error_messages=ErrMessage.dict(_("Model settings")))

        model_params_setting = serializers.DictField(required=False, allow_null=True,
                                                     error_messages=ErrMessage.dict(_("Model parameter settings")))

        def is_valid(self, *, raise_exception=False):
            """Validate the fields, then check every message is a ``BaseMessage``.

            Failures always raise, whatever ``raise_exception`` is; that
            existing behaviour is kept.

            :return: ``True`` when validation passed (the override used to
                     return ``None``)
            """
            super().is_valid(raise_exception=True)
            message_list: List = self.initial_data.get('message_list')
            for message in message_list:
                if not isinstance(message, BaseMessage):
                    raise Exception(_("message type error"))
            return True

    def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
        """Return the serializer class that validates this step's arguments."""
        return self.InstanceSerializer

    def _run(self, manage: PipelineManage):
        """Call ``execute`` with the validated step args and publish the result."""
        chat_result = self.execute(**self.context['step_args'], manage=manage)
        manage.context['chat_result'] = chat_result

    @abstractmethod
    def execute(self, message_list: List[BaseMessage],
                chat_id, problem_text,
                post_response_handler: PostResponseHandler,
                model_id: str = None,
                user_id: str = None,
                paragraph_list=None,
                manage: PipelineManage = None,
                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
                no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
        """Perform the chat.

        Receives the validated ``InstanceSerializer`` values as keyword
        arguments plus ``manage``. The return value is stored by ``_run``
        under ``manage.context['chat_result']``. Does nothing in the base
        class.
        """
        pass

0 commit comments

Comments
 (0)