Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 24 additions & 82 deletions apps/application/flow/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import queue
import re
import threading
from functools import reduce
from typing import Iterator
from maxkb.const import CONFIG
from django.http import StreamingHttpResponse
Expand Down Expand Up @@ -435,8 +436,8 @@ async def anext_async(agen):

target_source_node_mapping = {
'TOOL': {'tool-lib-node': lambda n: [n.get('properties').get('node_data').get('tool_lib_id')],
'ai-chat-node': lambda n: [...([n.get('properties').get('node_data').get('mcp_tool_ids')] or []),
...([n.get('properties').get('node_data').get('tool_ids')] or [])]},
'ai-chat-node': lambda n: [*([n.get('properties').get('node_data').get('mcp_tool_ids')] or []),
*([n.get('properties').get('node_data').get('tool_ids')] or [])]},
'MODEL': {'ai-chat-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
'question-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
'speech-to-text-node': lambda n: [n.get('properties').get('node_data').get('stt_model_id')],
Expand Down Expand Up @@ -488,14 +489,28 @@ def get_workflow_resource(workflow, node_handle):
return []


def get_instance_resource(instance, source_type, source_id, target_type, field_call_list):
application_instance_field_call_dict = {
'TOOL': [lambda instance: instance.mcp_tool_ids or [], lambda instance: instance.tool_ids or []],
'MODEL': [lambda instance: [instance.model_id] if instance.model_id else [],
lambda instance: [instance.tts_model_id] if instance.tts_model_id else [],
lambda instance: [instance.stt_model_id] if instance.stt_model_id else []]
}
knowledge_instance_field_call_dict = {
'MODEL': [lambda instance: [instance.model_id] if instance.model_id else [],
lambda instance: [instance.tts_model_id] if instance.tts_model_id else [],
lambda instance: [instance.stt_model_id] if instance.stt_model_id else []],
}


def get_instance_resource(instance, source_type, source_id, instance_field_call_dict):
response = []
from system_manage.models.resource_mapping import ResourceMapping
for field_call in field_call_list:
target_id = field_call(instance)
if target_id:
response.append(ResourceMapping(source_type=source_type, target_type=target_type, source_id=source_id,
target_id=target_id))
for target_type, call_list in instance_field_call_dict.items():
target_id_list = reduce(lambda x, y: [*x, *y], [call(instance) for call in call_list], [])
if target_id_list:
for target_id in target_id_list:
response.append(ResourceMapping(source_type=source_type, target_type=target_type, source_id=source_id,
target_id=target_id))
return response


Expand All @@ -508,85 +523,12 @@ def save_workflow_mapping(workflow, source_type, source_id, other_resource_mappi
resource_mapping_list = get_workflow_resource(workflow,
get_node_handle_callback(source_type,
source_id))
resource_mapping_list += other_resource_mapping
if resource_mapping_list:
resource_mapping_list += other_resource_mapping
QuerySet(ResourceMapping).bulk_create(
{(str(item.target_type) + str(item.target_id)): item for item in resource_mapping_list}.values())


def save_simple_mapping(application, source_type, source_id):
"""
保存应用资源映射关系

Args:
application: 应用对象
source_type: 源类型
source_id: 源ID
"""
from system_manage.models.resource_mapping import ResourceMapping
from django.db.models import QuerySet
from application.models import ApplicationKnowledgeMapping # 假设模型在此处定义
from system_manage.models.resource_mapping import ResourceType
# 删除原有映射关系
QuerySet(ResourceMapping).filter(source_type=source_type, source_id=source_id).delete()

# 构建资源映射列表
resource_mapping_list = []

# 定义模型ID字段映射
model_fields = ['model_id', 'tts_model_id', 'stt_model_id']
for field in model_fields:
model_id = getattr(application, field, None)
if model_id:
resource_mapping_list.append(ResourceMapping(
source_type=source_type,
target_type=ResourceType.MODEL,
source_id=source_id,
target_id=model_id
))

# 定义工具ID字段映射
tool_fields = ['mcp_tool_ids', 'tool_ids']
for field in tool_fields:
tool_ids = getattr(application, field, []) or []
resource_mapping_list.extend([
ResourceMapping(
source_type=source_type,
target_type=ResourceType.TOOL,
source_id=source_id,
target_id=tool_id
) for tool_id in tool_ids if tool_id
])

# 处理知识库映射
knowledge_mappings = ApplicationKnowledgeMapping.objects.filter(
application_id=application.id
)
resource_mapping_list.extend([
ResourceMapping(
source_type=source_type,
target_type=ResourceType.KNOWLEDGE,
source_id=source_id,
target_id=km.knowledge_id
) for km in knowledge_mappings
])

# 处理应用ID映射
application_ids = getattr(application, 'application_ids', []) or []
resource_mapping_list.extend([
ResourceMapping(
source_type=source_type,
target_type=ResourceType.APPLICATION,
source_id=source_id,
target_id=app_id
) for app_id in application_ids if app_id
])

# 批量创建资源映射
if resource_mapping_list:
QuerySet(ResourceMapping).bulk_create(resource_mapping_list)


def get_tool_id_list(workflow):
_result = []
for node in workflow.get('nodes', []):
Expand Down
88 changes: 72 additions & 16 deletions apps/application/serializers/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from rest_framework.utils.formatting import lazy_format

from application.flow.common import Workflow
from application.models.application import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping, \
from application.models.application import Application, ApplicationTypeChoices, \
ApplicationFolder, ApplicationVersion
from application.models.application_access_token import ApplicationAccessToken
from application.serializers.common import update_resource_mapping_by_application
Expand Down Expand Up @@ -541,7 +541,9 @@ def insert_workflow(self, instance: Dict):

@staticmethod
def to_application_knowledge_mapping(application_id: str, knowledge_id: str):
return ApplicationKnowledgeMapping(id=uuid.uuid7(), application_id=application_id, knowledge_id=knowledge_id)
return ResourceMapping(id=uuid.uuid7(), source_id=application_id, target_id=knowledge_id,
source_type="APPLICATION",
target_type="KNOWLEDGE")

def insert_simple(self, instance: Dict):
self.is_valid(raise_exception=True)
Expand All @@ -560,7 +562,7 @@ def insert_simple(self, instance: Dict):
ApplicationAccessToken(application_id=application_model.id,
access_token=hashlib.md5(str(uuid.uuid7()).encode()).hexdigest()[8:24]).save()
# 插入关联数据
QuerySet(ApplicationKnowledgeMapping).bulk_create(application_knowledge_mapping_model_list)
QuerySet(ResourceMapping).bulk_create(application_knowledge_mapping_model_list)
return ApplicationCreateSerializer.ApplicationResponse(application_model).data

@transaction.atomic
Expand Down Expand Up @@ -785,7 +787,6 @@ def delete(self, with_valid=True):
self.is_valid()
application_id = self.data.get('application_id')
QuerySet(ApplicationVersion).filter(application_id=application_id).delete()
QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id).delete()
QuerySet(ResourceMapping).filter(
Q(target_id=application_id) | Q(source_id=application_id)
).delete()
Expand Down Expand Up @@ -884,7 +885,6 @@ def publish(self, instance, with_valid=True):
application_access_token.save()
else:
access_token = application_access_token.access_token
update_resource_mapping_by_application(self.data.get("application_id"))
del_application_access_token(access_token)
return self.one(with_valid=False)

Expand Down Expand Up @@ -921,6 +921,19 @@ def update_work_flow_model(instance):
if 'name' in node_data:
instance['name'] = node_data['name']
break
knowledge_node_list = ApplicationOperateSerializer.get_search_node(instance.get('work_flow'))
for knowledge_node in knowledge_node_list:
node_data = knowledge_node.get('properties').get('node_data')
# 全部知识库id
all_knowledge_id_list = node_data.get('all_knowledge_id_list') or []
# 用户修改的知识库id
knowledge_id_list = node_data.get('knowledge_id_list') or []
# 用户可以看到的知识库
knowledge_list = node_data.get('knowledge_list') or []
view_knowledge_id_list = [knowledge.get('id') for knowledge in knowledge_list]
other_knowledge_id_list = [knowledge_id for knowledge_id in all_knowledge_id_list if
not view_knowledge_id_list.__contains__(knowledge_id)]
node_data['knowledge_id_list'] = other_knowledge_id_list + knowledge_id_list

@transaction.atomic
def edit(self, instance: Dict, with_valid=True):
Expand Down Expand Up @@ -972,19 +985,25 @@ def edit(self, instance: Dict, with_valid=True):
if update_key in instance and instance.get(update_key) is not None:
application.__setattr__(update_key, instance.get(update_key))
application.save()

# 当前用户可修改关联的知识库列表
application_knowledge_id_list = [str(knowledge.get('id')) for knowledge in
self.list_knowledge(with_valid=False)]
knowledge_id_list = []
if 'knowledge_id_list' in instance:
knowledge_id_list = instance.get('knowledge_id_list')
# 当前用户可修改关联的知识库列表
application_knowledge_id_list = [str(knowledge.get('id')) for knowledge in
self.list_knowledge(with_valid=False)]
knowledge_id_list = instance.get('knowledge_id_list')
for knowledge_id in knowledge_id_list:
if not application_knowledge_id_list.__contains__(knowledge_id):
message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'),
dataset_id=knowledge_id)
raise AppApiException(500, str(message))

self.save_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id)
update_resource_mapping_by_application(application_id,
self.get_application_knowledge_mapping(application_knowledge_id_list,
knowledge_id_list,
application_id))
return self.one(with_valid=False)

def update_template_workflow(self, instance: Dict, app: Application):
Expand Down Expand Up @@ -1074,9 +1093,11 @@ def one(self, with_valid=True):
knowledge_list = []
knowledge_id_list = []
if application.type == 'SIMPLE':
mapping_knowledge_list = QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id)
knowledge_list = [available_knowledge_dict.get(str(km.knowledge_id)) for km in mapping_knowledge_list if
available_knowledge_dict.__contains__(str(km.knowledge_id))]
mapping_knowledge_list = QuerySet(ResourceMapping).filter(source_id=application_id,
source_type="APPLICATION",
target_type="KNOWLEDGE")
knowledge_list = [available_knowledge_dict.get(str(km.target_id)) for km in mapping_knowledge_list if
available_knowledge_dict.__contains__(str(km.target_id))]
knowledge_id_list = [k.get('id') for k in knowledge_list]
else:
self.update_knowledge_node(application.work_flow, available_knowledge_dict)
Expand All @@ -1089,7 +1110,17 @@ def one(self, with_valid=True):
def get_search_node(work_flow):
if work_flow is None:
return []
return [node for node in work_flow.get('nodes', []) if node.get('type', '') == 'search-knowledge-node']
response = []
if 'nodes' in work_flow:
for node in work_flow.get('nodes'):
if node.get('type', '') == 'search-knowledge-node':
response.append(node)
if node.get('type') == 'loop-node':
r = ApplicationOperateSerializer.get_search_node(
node.get('properties', {}).get('node_data', {}).get('loop_body'))
for rn in r:
response.append(rn)
return response

def update_knowledge_node(self, workflow, available_knowledge_dict):
"""
Expand Down Expand Up @@ -1143,14 +1174,39 @@ def list_knowledge(self, with_valid=True):
def save_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id):
# 需要排除已删除的数据集
knowledge_id_list = [knowledge.id for knowledge in QuerySet(Knowledge).filter(id__in=knowledge_id_list)]

# 删除已经关联的id
QuerySet(ApplicationKnowledgeMapping).filter(knowledge_id__in=application_knowledge_id_list,
application_id=application_id).delete()
QuerySet(ResourceMapping).filter(target_id__in=application_knowledge_id_list,
source_id=application_id,
source_type='APPLICATION',
target_type="KNOWLEDGE").delete()
# 插入
QuerySet(ApplicationKnowledgeMapping).bulk_create(
[ApplicationKnowledgeMapping(application_id=application_id, knowledge_id=knowledge_id) for knowledge_id in
QuerySet(ResourceMapping).bulk_create(
[ResourceMapping(source_id=application_id, target_id=knowledge_id, source_type='APPLICATION',
target_type="KNOWLEDGE") for knowledge_id in
knowledge_id_list]) if len(knowledge_id_list) > 0 else None

@staticmethod
def get_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id):
"""

@param application_knowledge_id_list: 当前应用可修改的知识库列表
@param knowledge_id_list: 用户修改的知识库列表
@param application_id: 应用id
@return:
"""
# 当前知识库和应用已关联列表
knowledge_application_mapping_list = QuerySet(ResourceMapping).filter(source_id=application_id,
source_type='APPLICATION',
target_type="KNOWLEDGE",
).exclude(
target_id__in=application_knowledge_id_list)
edit_knowledge_list = [ResourceMapping(source_id=application_id, target_id=knowledge_id,
source_type='APPLICATION',
target_type="KNOWLEDGE")
for knowledge_id in knowledge_id_list]
return list(knowledge_application_mapping_list) + edit_knowledge_list

def speech_to_text(self, instance, debug=True, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
Expand Down
Loading