dify
This commit is contained in:
59
dify/api/controllers/service_api/__init__.py
Normal file
59
dify/api/controllers/service_api/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
# Blueprint carrying every service API route; all URLs are prefixed with /v1.
bp = Blueprint("service_api", __name__, url_prefix="/v1")

# API object bound to the blueprint, with the metadata passed to ExternalApi.
api = ExternalApi(bp, version="1.0", title="Service API", description="API for application services")

# Namespace rooted at "/" that the controller modules register their routes on.
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||
|
||||
from . import index
|
||||
from .app import (
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
file,
|
||||
file_preview,
|
||||
message,
|
||||
site,
|
||||
workflow,
|
||||
)
|
||||
from .dataset import (
|
||||
dataset,
|
||||
document,
|
||||
hit_testing,
|
||||
metadata,
|
||||
segment,
|
||||
)
|
||||
from .workspace import models
|
||||
|
||||
# Controller modules re-exported by this package, kept in alphabetical order.
__all__ = [
    "annotation",
    "app",
    "audio",
    "completion",
    "conversation",
    "dataset",
    "document",
    "file",
    "file_preview",
    "hit_testing",
    "index",
    "message",
    "metadata",
    "models",
    "segment",
    "site",
    "workflow",
]

# Attach the namespace to the API once the controller modules above have been imported.
api.add_namespace(service_api_ns)
|
||||
0
dify/api/controllers/service_api/app/__init__.py
Normal file
0
dify/api/controllers/service_api/app/__init__.py
Normal file
186
dify/api/controllers/service_api/app/annotation.py
Normal file
186
dify/api/controllers/service_api/app/annotation.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.api import HTTPStatus
|
||||
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Request parsers for the annotation endpoints.

# JSON body used to create or update an annotation: both fields are mandatory.
annotation_create_parser = reqparse.RequestParser()
annotation_create_parser.add_argument(
    "question", required=True, type=str, location="json", help="Annotation question"
)
annotation_create_parser.add_argument(
    "answer", required=True, type=str, location="json", help="Annotation answer"
)

# JSON body used when toggling annotation reply: matching threshold plus embedding model selection.
annotation_reply_action_parser = reqparse.RequestParser()
annotation_reply_action_parser.add_argument(
    "score_threshold",
    required=True,
    type=float,
    location="json",
    help="Score threshold for annotation matching",
)
annotation_reply_action_parser.add_argument(
    "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name"
)
annotation_reply_action_parser.add_argument(
    "embedding_model_name", required=True, type=str, location="json", help="Embedding model name"
)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
    """Switch the annotation reply feature of an application on or off."""

    @service_api_ns.expect(annotation_reply_action_parser)
    @service_api_ns.doc("annotation_reply_action")
    @service_api_ns.doc(description="Enable or disable annotation reply feature")
    @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
    @service_api_ns.doc(
        responses={
            200: "Action completed successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_app_token
    def post(self, app_model: App, action: Literal["enable", "disable"]):
        """Enable or disable annotation reply feature.

        Args:
            app_model: Application supplied by ``validate_app_token``.
            action: Path segment selecting the operation.

        Returns:
            A ``(result, 200)`` tuple where ``result`` is whatever the
            annotation service returned for the requested action.

        Raises:
            ValueError: If ``action`` is neither ``"enable"`` nor ``"disable"``.
        """
        args = annotation_reply_action_parser.parse_args()
        if action == "enable":
            result = AppAnnotationService.enable_app_annotation(args, app_model.id)
        elif action == "disable":
            result = AppAnnotationService.disable_app_annotation(app_model.id)
        else:
            # The route converter is a plain string, so any path segment reaches this
            # method. Reject unknown actions explicitly instead of falling through to
            # an unbound ``result`` (UnboundLocalError).
            raise ValueError(f"Unsupported action {action!r}; expected 'enable' or 'disable'.")
        return result, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
class AnnotationReplyActionStatusApi(Resource):
    """Report the progress of a previously started annotation reply job."""

    @service_api_ns.doc("get_annotation_reply_action_status")
    @service_api_ns.doc(description="Get the status of an annotation reply action job")
    @service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"})
    @service_api_ns.doc(
        responses={
            200: "Job status retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Job not found",
        }
    )
    @validate_app_token
    def get(self, app_model: App, job_id, action):
        """Get the status of an annotation reply action job.

        The status is read from Redis under ``{action}_app_annotation_job_{job_id}``.
        When the status is ``"error"`` the message stored under
        ``{action}_app_annotation_error_{job_id}`` is returned as well.

        Returns:
            A ``(payload, 200)`` tuple whose payload has the keys ``job_id``,
            ``job_status`` and ``error_msg`` (empty unless an error message is stored).

        Raises:
            ValueError: If no status entry exists for the job.
        """
        job_id = str(job_id)
        app_annotation_job_key = f"{action}_app_annotation_job_{job_id}"
        cache_result = redis_client.get(app_annotation_job_key)
        if cache_result is None:
            raise ValueError("The job does not exist.")

        job_status = cache_result.decode()
        error_msg = ""
        if job_status == "error":
            app_annotation_error_key = f"{action}_app_annotation_error_{job_id}"
            cached_error = redis_client.get(app_annotation_error_key)
            # The error entry is stored under a separate key and may be absent;
            # keep the empty message rather than calling .decode() on None.
            if cached_error is not None:
                error_msg = cached_error.decode()

        return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
# Field layout of a paginated annotation listing response.
annotation_list_fields = {
    "data": fields.List(fields.Nested(annotation_fields)),
    "has_more": fields.Boolean,
    "limit": fields.Integer,
    "total": fields.Integer,
    "page": fields.Integer,
}


def build_annotation_list_model(api_or_ns: Api | Namespace):
    """Register and return the ``AnnotationList`` model on ``api_or_ns``.

    The shared field layout is copied, and its ``data`` entry is replaced by a
    list of the annotation model built for the same API or namespace.
    """
    annotation_model = build_annotation_model(api_or_ns)
    list_fields = {**annotation_list_fields, "data": fields.List(fields.Nested(annotation_model))}
    return api_or_ns.model("AnnotationList", list_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource):
    """Collection endpoint: page through an app's annotations or add a new one."""

    @service_api_ns.doc("list_annotations")
    @service_api_ns.doc(description="List annotations for the application")
    @service_api_ns.doc(
        responses={
            200: "Annotations retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
    def get(self, app_model: App):
        """List annotations for the application.

        Query parameters: ``page`` (default 1), ``limit`` (default 20) and
        ``keyword`` (default empty string).
        """
        query = request.args
        page = query.get("page", default=1, type=int)
        limit = query.get("limit", default=20, type=int)
        keyword = query.get("keyword", default="", type=str)

        annotations, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)

        # A completely filled page is reported as "more results may follow".
        page_is_full = len(annotations) == limit
        return {
            "data": annotations,
            "has_more": page_is_full,
            "limit": limit,
            "total": total,
            "page": page,
        }

    @service_api_ns.expect(annotation_create_parser)
    @service_api_ns.doc("create_annotation")
    @service_api_ns.doc(description="Create a new annotation")
    @service_api_ns.doc(
        responses={
            201: "Annotation created successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
    def post(self, app_model: App):
        """Create a new annotation.

        The question/answer pair is taken from the JSON body and stored for
        the authenticated application; the new annotation is returned with 201.
        """
        payload = annotation_create_parser.parse_args()
        created = AppAnnotationService.insert_app_annotation_directly(payload, app_model.id)
        return created, 201
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
    """Item endpoint: update or delete a single annotation of an application."""

    @service_api_ns.expect(annotation_create_parser)
    @service_api_ns.doc("update_annotation")
    @service_api_ns.doc(description="Update an existing annotation")
    @service_api_ns.doc(params={"annotation_id": "Annotation ID"})
    @service_api_ns.doc(
        responses={
            200: "Annotation updated successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Annotation not found",
        }
    )
    @validate_app_token
    @edit_permission_required
    @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
    def put(self, app_model: App, annotation_id: str):
        """Update an existing annotation.

        Args:
            app_model: Application supplied by ``validate_app_token``.
            annotation_id: Identifier taken from the URL path.

        Returns:
            The annotation returned by the service, marshalled with the annotation model.
        """
        # The ``uuid`` route converter hands over a ``uuid.UUID``; normalize it to the
        # string form the signature declares, as the job-status endpoint does for job_id.
        annotation_id = str(annotation_id)
        args = annotation_create_parser.parse_args()
        annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
        return annotation

    @service_api_ns.doc("delete_annotation")
    @service_api_ns.doc(description="Delete an annotation")
    @service_api_ns.doc(params={"annotation_id": "Annotation ID"})
    @service_api_ns.doc(
        responses={
            204: "Annotation deleted successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Annotation not found",
        }
    )
    @validate_app_token
    @edit_permission_required
    def delete(self, app_model: App, annotation_id: str):
        """Delete an annotation.

        Args:
            app_model: Application supplied by ``validate_app_token``.
            annotation_id: Identifier taken from the URL path.

        Returns:
            ``({"result": "success"}, 204)``.
        """
        # Same normalization as in ``put``: the route converter yields a ``uuid.UUID``.
        annotation_id = str(annotation_id)
        AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
        return {"result": "success"}, 204
|
||||
95
dify/api/controllers/service_api/app/app.py
Normal file
95
dify/api/controllers/service_api/app/app.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import build_parameters_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
@service_api_ns.route("/parameters")
class AppParameterApi(Resource):
    """Expose the input form and feature configuration of an application."""

    @service_api_ns.doc("get_app_parameters")
    @service_api_ns.doc(description="Retrieve application input parameters and configuration")
    @service_api_ns.doc(
        responses={
            200: "Parameters retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Application not found",
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
    def get(self, app_model: App):
        """Retrieve app parameters.

        Advanced-chat and workflow apps read features and the input form from
        their workflow; every other mode reads them from the app model config.
        A missing workflow or config raises ``AppUnavailableError``.
        """
        if app_model.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
            # Config-backed modes: the form is one entry of the serialized config.
            model_config = app_model.app_model_config
            if model_config is None:
                raise AppUnavailableError()

            features_dict = model_config.to_dict()
            user_input_form = features_dict.get("user_input_form", [])
        else:
            # Workflow-backed modes: features and form both come from the workflow.
            workflow = app_model.workflow
            if workflow is None:
                raise AppUnavailableError()

            features_dict = workflow.features_dict
            user_input_form = workflow.user_input_form(to_old_structure=True)

        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
@service_api_ns.route("/meta")
class AppMetaApi(Resource):
    """Expose application metadata through the service API."""

    @service_api_ns.doc("get_app_meta")
    @service_api_ns.doc(description="Get application metadata")
    @service_api_ns.doc(
        responses={
            200: "Metadata retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Application not found",
        }
    )
    @validate_app_token
    def get(self, app_model: App):
        """Get app metadata.

        Delegates to ``AppService.get_app_meta`` and returns its result unchanged.
        """
        app_service = AppService()
        return app_service.get_app_meta(app_model)
|
||||
|
||||
|
||||
@service_api_ns.route("/info")
class AppInfoApi(Resource):
    """Expose a short summary of the application."""

    @service_api_ns.doc("get_app_info")
    @service_api_ns.doc(description="Get basic application information")
    @service_api_ns.doc(
        responses={
            200: "Application info retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Application not found",
        }
    )
    @validate_app_token
    def get(self, app_model: App):
        """Get app information.

        The response holds the app's name, description, tag names, mode and author name.
        """
        tag_names = [tag.name for tag in app_model.tags]
        info = {
            "name": app_model.name,
            "description": app_model.description,
            "tags": tag_names,
            "mode": app_model.mode,
            "author_name": app_model.author_name,
        }
        return info
|
||||
150
dify/api/controllers/service_api/app/audio.py
Normal file
150
dify/api/controllers/service_api/app/audio.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@service_api_ns.route("/audio-to-text")
class AudioApi(Resource):
    """Speech-to-text endpoint of the service API."""

    @service_api_ns.doc("audio_to_text")
    @service_api_ns.doc(description="Convert audio to text using speech-to-text")
    @service_api_ns.doc(
        responses={
            200: "Audio successfully transcribed",
            400: "Bad request - no audio or invalid audio",
            401: "Unauthorized - invalid API token",
            413: "Audio file too large",
            415: "Unsupported audio type",
            500: "Internal server error",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
    def post(self, app_model: App, end_user: EndUser):
        """Convert audio to text using speech-to-text.

        Accepts an audio file upload in the ``file`` form field and returns the
        result of ``AudioService.transcript_asr``.

        Raises:
            NoAudioUploadedError: If the request has no ``file`` part, or the
                service reports that no audio was uploaded.
            AppUnavailableError, AudioTooLargeError, UnsupportedAudioTypeError,
            ProviderNotSupportSpeechToTextError, ProviderNotInitializeError,
            ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError,
            CompletionRequestError: Translations of the matching service/core errors.
            ValueError: Propagated unchanged from the service.
            InternalServerError: For any other unexpected exception.
        """
        # Look the upload up without indexing so that a missing part is reported
        # with the dedicated "no audio" error instead of a generic key error.
        file = request.files.get("file")
        if file is None:
            raise NoAudioUploadedError()

        try:
            return AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError:
            # Let ValueError through untouched so the generic handler below does not mask it.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
|
||||
|
||||
|
||||
# JSON body accepted by the text-to-audio endpoint; every field is optional.
text_to_audio_parser = reqparse.RequestParser()
text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio")
text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
|
||||
|
||||
|
||||
@service_api_ns.route("/text-to-audio")
class TextApi(Resource):
    """Text-to-speech endpoint of the service API."""

    @service_api_ns.expect(text_to_audio_parser)
    @service_api_ns.doc("text_to_audio")
    @service_api_ns.doc(description="Convert text to audio using text-to-speech")
    @service_api_ns.doc(
        responses={
            200: "Text successfully converted to audio",
            400: "Bad request - invalid parameters",
            401: "Unauthorized - invalid API token",
            500: "Internal server error",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
    def post(self, app_model: App, end_user: EndUser):
        """Convert text to audio using text-to-speech.

        Reads ``message_id``, ``text`` and ``voice`` from the JSON body and
        returns the result of ``AudioService.transcript_tts``.

        Raises:
            AppUnavailableError, NoAudioUploadedError, AudioTooLargeError,
            UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError,
            ProviderNotInitializeError, ProviderQuotaExceededError,
            ProviderModelCurrentlyNotSupportError, CompletionRequestError:
                Translations of the matching service/core errors.
            ValueError: Propagated unchanged from the service.
            InternalServerError: For any other unexpected exception.
        """
        # Parse before entering the try block: an HTTP abort raised by the parser
        # must reach the client as-is rather than be rewritten to a 500 by the
        # catch-all handler below.
        args = text_to_audio_parser.parse_args()
        message_id = args.get("message_id")
        text = args.get("text")
        voice = args.get("voice")

        try:
            return AudioService.transcript_tts(
                app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
            )
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except NoAudioUploadedServiceError:
            raise NoAudioUploadedError()
        except AudioTooLargeServiceError as e:
            raise AudioTooLargeError(str(e))
        except UnsupportedAudioTypeServiceError:
            raise UnsupportedAudioTypeError()
        except ProviderNotSupportSpeechToTextServiceError:
            raise ProviderNotSupportSpeechToTextError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError:
            # Let ValueError through untouched so the generic handler below does not mask it.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
|
||||
259
dify/api/controllers/service_api/app/completion.py
Normal file
259
dify/api/controllers/service_api/app/completion.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NotChatAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# JSON body accepted by the completion endpoint.
completion_parser = reqparse.RequestParser()
completion_parser.add_argument(
    "inputs", type=dict, required=True, location="json", help="Input parameters for completion"
)
completion_parser.add_argument("query", type=str, location="json", default="", help="The query string")
completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
completion_parser.add_argument(
    "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
)
completion_parser.add_argument(
    "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
)

# JSON body accepted by the chat endpoint.
chat_parser = reqparse.RequestParser()
chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query")
chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
chat_parser.add_argument(
    "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
)
chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
chat_parser.add_argument(
    "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
)
chat_parser.add_argument(
    "auto_generate_name",
    type=bool,
    required=False,
    default=True,
    location="json",
    help="Auto generate conversation name",
)
chat_parser.add_argument(
    "workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat"
)
|
||||
|
||||
|
||||
@service_api_ns.route("/completion-messages")
class CompletionApi(Resource):
    """Completion generation endpoint of the service API."""

    @service_api_ns.expect(completion_parser)
    @service_api_ns.doc("create_completion")
    @service_api_ns.doc(description="Create a completion for the given prompt")
    @service_api_ns.doc(
        responses={
            200: "Completion created successfully",
            400: "Bad request - invalid parameters",
            401: "Unauthorized - invalid API token",
            404: "Conversation not found",
            500: "Internal server error",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser):
        """Create a completion for the given prompt.

        Only apps in completion mode are accepted. The parsed JSON body is
        handed to ``AppGenerateService.generate`` with conversation auto-naming
        switched off; ``response_mode`` selects blocking or streaming output.
        """
        if app_model.mode != AppMode.COMPLETION:
            raise AppUnavailableError()

        payload = completion_parser.parse_args()

        # Forward an externally supplied trace id when the request carries one.
        if trace_id := get_external_trace_id(request):
            payload["external_trace_id"] = trace_id

        payload["auto_generate_name"] = False
        wants_stream = payload["response_mode"] == "streaming"

        try:
            generated = AppGenerateService.generate(
                app_model=app_model,
                user=end_user,
                args=payload,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=wants_stream,
            )
            return helper.compact_generate_response(generated)
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError as err:
            raise ProviderNotInitializeError(err.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as err:
            raise CompletionRequestError(err.description)
        except ValueError:
            # Passed through unchanged; must not fall into the catch-all below.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
|
||||
|
||||
|
||||
@service_api_ns.route("/completion-messages/<string:task_id>/stop")
class CompletionStopApi(Resource):
    """Stop endpoint for running completion tasks."""

    @service_api_ns.doc("stop_completion")
    @service_api_ns.doc(description="Stop a running completion task")
    @service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
    @service_api_ns.doc(
        responses={
            200: "Task stopped successfully",
            401: "Unauthorized - invalid API token",
            404: "Task not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, task_id: str):
        """Stop a running completion task.

        Rejects apps that are not in completion mode, then asks
        ``AppTaskService`` to stop the task on behalf of the end user.
        """
        if app_model.mode != AppMode.COMPLETION:
            raise AppUnavailableError()

        resolved_mode = AppMode.value_of(app_model.mode)
        AppTaskService.stop_task(
            task_id=task_id,
            invoke_from=InvokeFrom.SERVICE_API,
            user_id=end_user.id,
            app_mode=resolved_mode,
        )

        outcome = {"result": "success"}
        return outcome, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/chat-messages")
class ChatApi(Resource):
    """Chat message endpoint of the service API."""

    @service_api_ns.expect(chat_parser)
    @service_api_ns.doc("create_chat_message")
    @service_api_ns.doc(description="Send a message in a chat conversation")
    @service_api_ns.doc(
        responses={
            200: "Message sent successfully",
            400: "Bad request - invalid parameters or workflow issues",
            401: "Unauthorized - invalid API token",
            404: "Conversation or workflow not found",
            429: "Rate limit exceeded",
            500: "Internal server error",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser):
        """Send a message in a chat conversation.

        Accepted for chat, agent-chat and advanced-chat apps only. The parsed
        JSON body is handed to ``AppGenerateService.generate``; ``response_mode``
        selects blocking or streaming output.
        """
        chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
        if AppMode.value_of(app_model.mode) not in chat_modes:
            raise NotChatAppError()

        payload = chat_parser.parse_args()

        # Forward an externally supplied trace id when the request carries one.
        if trace_id := get_external_trace_id(request):
            payload["external_trace_id"] = trace_id

        wants_stream = payload["response_mode"] == "streaming"

        try:
            generated = AppGenerateService.generate(
                app_model=app_model,
                user=end_user,
                args=payload,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=wants_stream,
            )
            return helper.compact_generate_response(generated)
        except WorkflowNotFoundError as err:
            raise NotFound(str(err))
        except (IsDraftWorkflowError, WorkflowIdFormatError) as err:
            # Both workflow problems are reported as a bad request carrying the service message.
            raise BadRequest(str(err))
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
            raise ConversationCompletedError()
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logger.exception("App model config broken.")
            raise AppUnavailableError()
        except ProviderTokenNotInitError as err:
            raise ProviderNotInitializeError(err.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeRateLimitError as err:
            raise InvokeRateLimitHttpError(err.description)
        except InvokeError as err:
            raise CompletionRequestError(err.description)
        except ValueError:
            # Passed through unchanged; must not fall into the catch-all below.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
|
||||
|
||||
|
||||
@service_api_ns.route("/chat-messages/<string:task_id>/stop")
class ChatStopApi(Resource):
    """Stop endpoint for running chat message generations."""

    @service_api_ns.doc("stop_chat_message")
    @service_api_ns.doc(description="Stop a running chat message generation")
    @service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
    @service_api_ns.doc(
        responses={
            200: "Task stopped successfully",
            401: "Unauthorized - invalid API token",
            404: "Task not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, task_id: str):
        """Stop a running chat message generation.

        Rejects apps outside the chat, agent-chat and advanced-chat modes, then
        asks ``AppTaskService`` to stop the task on behalf of the end user.
        """
        resolved_mode = AppMode.value_of(app_model.mode)
        chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
        if resolved_mode not in chat_modes:
            raise NotChatAppError()

        AppTaskService.stop_task(
            task_id=task_id,
            invoke_from=InvokeFrom.SERVICE_API,
            user_id=end_user.id,
            app_mode=resolved_mode,
        )

        outcome = {"result": "success"}
        return outcome, 200
|
||||
264
dify/api/controllers/service_api/app/conversation.py
Normal file
264
dify/api/controllers/service_api/app/conversation.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx._http import HTTPStatus
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
build_conversation_delete_model,
|
||||
build_conversation_infinite_scroll_pagination_model,
|
||||
build_simple_conversation_model,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
# Request parsers for the conversation endpoints.

# Conversation listing: cursor (last_id), page size and sort order, all read from the query string.
conversation_list_parser = reqparse.RequestParser()
conversation_list_parser.add_argument(
    "last_id",
    location="args",
    type=uuid_value,
    help="Last conversation ID for pagination",
)
conversation_list_parser.add_argument(
    "limit",
    location="args",
    type=int_range(1, 100),
    default=20,
    required=False,
    help="Number of conversations to return",
)
conversation_list_parser.add_argument(
    "sort_by",
    location="args",
    type=str,
    choices=["created_at", "-created_at", "updated_at", "-updated_at"],
    default="-updated_at",
    required=False,
    help="Sort order for conversations",
)

# Conversation rename: an explicit name and/or a flag asking for an auto-generated one (JSON body).
conversation_rename_parser = reqparse.RequestParser()
conversation_rename_parser.add_argument(
    "name",
    location="json",
    type=str,
    required=False,
    help="New conversation name",
)
conversation_rename_parser.add_argument(
    "auto_generate",
    location="json",
    type=bool,
    default=False,
    required=False,
    help="Auto-generate conversation name",
)

# Conversation variable listing: cursor (last_id) and page size from the query string.
conversation_variables_parser = reqparse.RequestParser()
conversation_variables_parser.add_argument(
    "last_id",
    location="args",
    type=uuid_value,
    help="Last variable ID for pagination",
)
conversation_variables_parser.add_argument(
    "limit",
    location="args",
    type=int_range(1, 100),
    default=20,
    required=False,
    help="Number of variables to return",
)

# Conversation variable update: the new value from the JSON body.
# The identity lambda passes the already-typed JSON value through unchanged;
# without it the value would be converted to a string, and that string
# cannot be converted back using json.loads.
conversation_variable_update_parser = reqparse.RequestParser()
conversation_variable_update_parser.add_argument(
    "value",
    location="json",
    type=lambda x: x,
    required=True,
    help="New value for the conversation variable",
)
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations")
class ConversationApi(Resource):
    # Lists conversations of the end user for chat-style apps.

    @service_api_ns.expect(conversation_list_parser)
    @service_api_ns.doc(
        id="list_conversations",
        description="List all conversations for the current user",
        responses={
            200: "Conversations retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Last conversation not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
    def get(self, app_model: App, end_user: EndUser):
        """List all conversations for the current user.

        Supports pagination using last_id and limit parameters.
        """
        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        query = conversation_list_parser.parse_args()

        try:
            # The session is scoped to this single service call.
            with Session(db.engine) as session:
                page = ConversationService.pagination_by_last_id(
                    session=session,
                    app_model=app_model,
                    user=end_user,
                    last_id=query["last_id"],
                    limit=query["limit"],
                    invoke_from=InvokeFrom.SERVICE_API,
                    sort_by=query["sort_by"],
                )
        except services.errors.conversation.LastConversationNotExistsError:
            # An unknown pagination cursor is reported as 404.
            raise NotFound("Last Conversation Not Exists.")
        return page
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>")
class ConversationDetailApi(Resource):
    # Deletes a single conversation of a chat-style app.

    @service_api_ns.doc(
        id="delete_conversation",
        description="Delete a specific conversation",
        params={"c_id": "Conversation ID"},
        responses={
            204: "Conversation deleted successfully",
            401: "Unauthorized - invalid API token",
            404: "Conversation not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
    @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
    def delete(self, app_model: App, end_user: EndUser, c_id):
        """Delete a specific conversation."""
        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        try:
            # The route converter yields a UUID; the service is given its string form.
            ConversationService.delete(app_model, str(c_id), end_user)
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        else:
            return {"result": "success"}, 204
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
class ConversationRenameApi(Resource):
    # Renames a conversation, either to a given name or to an auto-generated one.

    @service_api_ns.expect(conversation_rename_parser)
    @service_api_ns.doc(
        id="rename_conversation",
        description="Rename a conversation or auto-generate a name",
        params={"c_id": "Conversation ID"},
        responses={
            200: "Conversation renamed successfully",
            401: "Unauthorized - invalid API token",
            404: "Conversation not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
    @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
    def post(self, app_model: App, end_user: EndUser, c_id):
        """Rename a conversation or auto-generate a name."""
        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        # The route converter yields a UUID; the service is given its string form.
        target_id = str(c_id)
        payload = conversation_rename_parser.parse_args()

        try:
            renamed = ConversationService.rename(
                app_model,
                target_id,
                end_user,
                payload["name"],
                payload["auto_generate"],
            )
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        return renamed
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
class ConversationVariablesApi(Resource):
    # Lists the variables attached to one conversation.

    @service_api_ns.expect(conversation_variables_parser)
    @service_api_ns.doc(
        id="list_conversation_variables",
        description="List all variables for a conversation",
        params={"c_id": "Conversation ID"},
        responses={
            200: "Variables retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Conversation not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
    def get(self, app_model: App, end_user: EndUser, c_id):
        """List all variables for a conversation.

        Conversational variables are only available for chat applications.
        """
        # Conversation variables exist only for chat-style apps.
        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        # The route converter yields a UUID; the service is given its string form.
        target_id = str(c_id)
        query = conversation_variables_parser.parse_args()

        try:
            variables_page = ConversationService.get_conversational_variable(
                app_model,
                target_id,
                end_user,
                query["limit"],
                query["last_id"],
            )
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        return variables_page
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
class ConversationVariableDetailApi(Resource):
    # Updates the value of one conversation variable.

    @service_api_ns.expect(conversation_variable_update_parser)
    @service_api_ns.doc(
        id="update_conversation_variable",
        description="Update a conversation variable's value",
        params={"c_id": "Conversation ID", "variable_id": "Variable ID"},
        responses={
            200: "Variable updated successfully",
            400: "Bad request - type mismatch",
            401: "Unauthorized - invalid API token",
            404: "Conversation or variable not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
    @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
    def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
        """Update a conversation variable's value.

        Allows updating the value of a specific conversation variable.
        The value must match the variable's expected type.
        """
        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        # Both path parameters come in as UUIDs; the service is given strings.
        conversation_key = str(c_id)
        variable_key = str(variable_id)

        payload = conversation_variable_update_parser.parse_args()

        # Map the service-level errors onto HTTP errors: missing records -> 404, type mismatch -> 400.
        try:
            updated = ConversationService.update_conversation_variable(
                app_model,
                conversation_key,
                variable_key,
                end_user,
                payload["value"],
            )
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationVariableNotExistsError:
            raise NotFound("Conversation Variable Not Exists.")
        except services.errors.conversation.ConversationVariableTypeMismatchError as e:
            raise BadRequest(str(e))
        return updated
|
||||
97
dify/api/controllers/service_api/app/error.py
Normal file
97
dify/api/controllers/service_api/app/error.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class AppUnavailableError(BaseHTTPException):
    """HTTP 400 error with code ``app_unavailable``: the app cannot be used as configured."""

    code = 400
    error_code = "app_unavailable"
    description = "App unavailable, please check your app configurations."
|
||||
|
||||
|
||||
class NotCompletionAppError(BaseHTTPException):
    """HTTP 400 error with code ``not_completion_app``: the route expects a completion app."""

    code = 400
    error_code = "not_completion_app"
    description = "Please check if your Completion app mode matches the right API route."
|
||||
|
||||
|
||||
class NotChatAppError(BaseHTTPException):
    """HTTP 400 error with code ``not_chat_app``: the route expects a chat-style app."""

    code = 400
    error_code = "not_chat_app"
    description = "Please check if your app mode matches the right API route."
|
||||
|
||||
|
||||
class NotWorkflowAppError(BaseHTTPException):
    """HTTP 400 error with code ``not_workflow_app``: the route expects a workflow app."""

    code = 400
    error_code = "not_workflow_app"
    description = "Please check if your app mode matches the right API route."
|
||||
|
||||
|
||||
class ConversationCompletedError(BaseHTTPException):
    """HTTP 400 error with code ``conversation_completed``: the conversation has already ended."""

    code = 400
    error_code = "conversation_completed"
    description = "The conversation has ended. Please start a new conversation."
|
||||
|
||||
|
||||
class ProviderNotInitializeError(BaseHTTPException):
    """HTTP 400 error with code ``provider_not_initialize``: no usable model provider credentials."""

    code = 400
    error_code = "provider_not_initialize"
    description = (
        "No valid model provider credentials found. Please go to Settings -> "
        "Model Provider to complete your provider credentials."
    )
|
||||
|
||||
|
||||
class ProviderQuotaExceededError(BaseHTTPException):
    """HTTP 400 error with code ``provider_quota_exceeded``: the hosted provider quota is used up."""

    code = 400
    error_code = "provider_quota_exceeded"
    description = (
        "Your quota for Dify Hosted OpenAI has been exhausted. Please go to Settings -> "
        "Model Provider to complete your own provider credentials."
    )
|
||||
|
||||
|
||||
class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
    """HTTP 400 error with code ``model_currently_not_support``: the requested model is not offered."""

    code = 400
    error_code = "model_currently_not_support"
    description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
|
||||
|
||||
|
||||
class CompletionRequestError(BaseHTTPException):
    """HTTP 400 error with code ``completion_request_error``: a completion request failed."""

    code = 400
    error_code = "completion_request_error"
    description = "Completion request failed."
|
||||
|
||||
|
||||
class NoAudioUploadedError(BaseHTTPException):
    """HTTP 400 error with code ``no_audio_uploaded``: the request carried no audio."""

    code = 400
    error_code = "no_audio_uploaded"
    description = "Please upload your audio."
|
||||
|
||||
|
||||
class AudioTooLargeError(BaseHTTPException):
    """HTTP 413 error with code ``audio_too_large``: the uploaded audio exceeds the size limit."""

    code = 413
    error_code = "audio_too_large"
    # NOTE(review): "{message}" looks like a format placeholder; nothing in this
    # class fills it in, so confirm where (or whether) it is substituted.
    description = "Audio size exceeded. {message}"
|
||||
|
||||
|
||||
class UnsupportedAudioTypeError(BaseHTTPException):
    """HTTP 415 error with code ``unsupported_audio_type``: the audio type is not allowed."""

    code = 415
    error_code = "unsupported_audio_type"
    description = "Audio type not allowed."
|
||||
|
||||
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
    """HTTP 400 error with code ``provider_not_support_speech_to_text``."""

    code = 400
    error_code = "provider_not_support_speech_to_text"
    description = "Provider not support speech to text."
|
||||
|
||||
|
||||
class FileNotFoundError(BaseHTTPException):
    """HTTP 404 error with code ``file_not_found``: the requested file does not exist."""

    # NOTE: this name shadows Python's builtin FileNotFoundError in this module
    # and in every module that imports it by name; an ``except FileNotFoundError``
    # there will NOT catch the builtin OSError subclass.
    code = 404
    error_code = "file_not_found"
    description = "The requested file was not found."
|
||||
|
||||
|
||||
class FileAccessDeniedError(BaseHTTPException):
    """HTTP 403 error with code ``file_access_denied``: the caller may not access the file."""

    code = 403
    error_code = "file_access_denied"
    description = "Access to the requested file is denied."
|
||||
67
dify/api/controllers/service_api/app/file.py
Normal file
67
dify/api/controllers/service_api/app/file.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import build_file_model
|
||||
from models import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@service_api_ns.route("/files/upload")
class FileApi(Resource):
    # Accepts one multipart file upload and stores it through FileService.

    @service_api_ns.doc(
        id="upload_file",
        description="Upload a file for use in conversations",
        responses={
            201: "File uploaded successfully",
            400: "Bad request - no file or invalid file",
            401: "Unauthorized - invalid API token",
            413: "File too large",
            415: "Unsupported file type",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
    @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
    def post(self, app_model: App, end_user: EndUser):
        """Upload a file for use in conversations.

        Accepts a single file upload via multipart/form-data.
        """
        # Exactly one part, named "file", must be present.
        if "file" not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()

        incoming = request.files["file"]

        # Both a MIME type and a filename are required before storing anything.
        if not incoming.mimetype:
            raise UnsupportedFileTypeError()
        if not incoming.filename:
            raise FilenameNotExistsError()

        # Service-level failures are translated into the controller-level HTTP errors.
        try:
            file_service = FileService(db.engine)
            stored = file_service.upload_file(
                filename=incoming.filename,
                content=incoming.read(),
                mimetype=incoming.mimetype,
                user=end_user,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return stored, 201
|
||||
186
dify/api/controllers/service_api/app/file_preview.py
Normal file
186
dify/api/controllers/service_api/app/file_preview.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, inputs, reqparse
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
FileAccessDeniedError,
|
||||
FileNotFoundError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, EndUser, Message, MessageFile, UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for file preview API
# `as_attachment` arrives as a query-string value, i.e. always as text.
# `inputs.boolean` parses "true"/"false"/"1"/"0" properly; plain `bool` would
# turn every non-empty string (including "false" and "0") into True.
file_preview_parser = reqparse.RequestParser().add_argument(
    "as_attachment",
    type=inputs.boolean,
    required=False,
    default=False,
    location="args",
    help="Download as attachment",
)
|
||||
|
||||
|
||||
@service_api_ns.route("/files/<uuid:file_id>/preview")
class FilePreviewApi(Resource):
    """
    Service API File Preview endpoint

    Provides secure file preview/download functionality for external API users.
    Files can only be accessed if they belong to messages within the requesting app's context.
    """

    @service_api_ns.expect(file_preview_parser)
    @service_api_ns.doc("preview_file")
    @service_api_ns.doc(description="Preview or download a file uploaded via Service API")
    @service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
    @service_api_ns.doc(
        responses={
            200: "File retrieved successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - file access denied",
            404: "File not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    def get(self, app_model: App, end_user: EndUser, file_id: str):
        """
        Preview/Download a file that was uploaded via Service API.

        Provides secure file preview/download functionality.
        Files can only be accessed if they belong to messages within the requesting app's context.
        """
        # The route converter passes a UUID; everything below works with its string form.
        file_id = str(file_id)

        # Parse query parameters
        args = file_preview_parser.parse_args()

        # Validate file ownership and get file objects
        _, upload_file = self._validate_file_ownership(file_id, app_model.id)

        # Get file content generator
        try:
            generator = storage.load(upload_file.key, stream=True)
        except Exception as e:
            # Any storage failure is reported to the client as "file not found";
            # the original exception is chained for server-side diagnostics.
            raise FileNotFoundError(f"Failed to load file content: {str(e)}") from e

        # Build response with appropriate headers
        return self._build_file_response(generator, upload_file, args["as_attachment"])

    def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]:
        """
        Validate that the file belongs to a message within the requesting app's context

        Security validations performed:
        1. File exists in MessageFile table (was used in a conversation)
        2. Message belongs to the requesting app
        3. UploadFile record exists and is accessible
        4. File tenant matches app tenant (additional security layer)

        Args:
            file_id: UUID of the file to validate
            app_id: UUID of the requesting app

        Returns:
            Tuple of (MessageFile, UploadFile) if validation passes

        Raises:
            FileNotFoundError: File or related records not found
            FileAccessDeniedError: File does not belong to the app's context
        """
        try:
            # Input validation
            if not file_id or not app_id:
                raise FileAccessDeniedError("Invalid file or app identifier")

            # First, find the MessageFile that references this upload file.
            # NOTE(review): only the first matching MessageFile is checked; if one
            # upload can be referenced by messages of several apps, a legitimate
            # owner could be denied depending on row order — confirm.
            message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()

            if not message_file:
                raise FileNotFoundError("File not found in message context")

            # Get the message and verify it belongs to the requesting app
            message = (
                db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
            )

            if not message:
                raise FileAccessDeniedError("File access denied: not owned by requesting app")

            # Get the actual upload file record
            upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()

            if not upload_file:
                raise FileNotFoundError("Upload file record not found")

            # Additional security: verify tenant isolation.
            # The check is skipped when the app row cannot be loaded.
            app = db.session.query(App).where(App.id == app_id).first()
            if app and upload_file.tenant_id != app.tenant_id:
                raise FileAccessDeniedError("File access denied: tenant mismatch")

            return message_file, upload_file

        except (FileNotFoundError, FileAccessDeniedError):
            # Re-raise our custom exceptions
            # (FileNotFoundError here is the HTTP error imported by this module, not the builtin.)
            raise
        except Exception as e:
            # Log unexpected errors for debugging, then fail closed with a 403.
            logger.exception(
                "Unexpected error during file ownership validation",
                extra={"file_id": file_id, "app_id": app_id, "error": str(e)},
            )
            raise FileAccessDeniedError("File access validation failed") from e

    def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response:
        """
        Build Flask Response object with appropriate headers for file streaming

        Args:
            generator: File content generator from storage
            upload_file: UploadFile database record
            as_attachment: Whether to set Content-Disposition as attachment

        Returns:
            Flask Response object with streaming file content
        """
        response = Response(
            generator,
            mimetype=upload_file.mime_type,
            direct_passthrough=True,
            headers={},
        )

        # Add Content-Length if known
        if upload_file.size and upload_file.size > 0:
            response.headers["Content-Length"] = str(upload_file.size)

        # Add Accept-Ranges header for audio/video files to support seeking
        if upload_file.mime_type in (
            "audio/mpeg",
            "audio/wav",
            "audio/mp4",
            "audio/ogg",
            "audio/flac",
            "audio/aac",
            "video/mp4",
            "video/webm",
            "video/quicktime",
            "audio/x-m4a",
        ):
            response.headers["Accept-Ranges"] = "bytes"

        # Set Content-Disposition for downloads
        if as_attachment and upload_file.name:
            # Percent-encode the name so non-ASCII filenames survive in the RFC 5987 `filename*` form.
            encoded_filename = quote(upload_file.name)
            response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
            # Override content-type for downloads to force download
            response.headers["Content-Type"] = "application/octet-stream"

        # Cache for 1 hour, but only in the requesting client's own cache: this
        # response is served behind an API token, so it must not be marked
        # `public`, which would allow shared caches (proxies/CDNs) to store it and
        # replay it to other callers.
        response.headers["Cache-Control"] = "private, max-age=3600"

        return response
|
||||
237
dify/api/controllers/service_api/app/message.py
Normal file
237
dify/api/controllers/service_api/app/message.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Request parsers for the message endpoints.

# Message listing: conversation to read, cursor (first_id) and page size, all from the query string.
message_list_parser = reqparse.RequestParser()
message_list_parser.add_argument(
    "conversation_id",
    location="args",
    type=uuid_value,
    required=True,
    help="Conversation ID",
)
message_list_parser.add_argument(
    "first_id",
    location="args",
    type=uuid_value,
    help="First message ID for pagination",
)
message_list_parser.add_argument(
    "limit",
    location="args",
    type=int_range(1, 100),
    default=20,
    required=False,
    help="Number of messages to return",
)

# Message feedback: rating ("like", "dislike" or null) plus optional free text, from the JSON body.
message_feedback_parser = reqparse.RequestParser()
message_feedback_parser.add_argument(
    "rating",
    location="json",
    type=str,
    choices=["like", "dislike", None],
    help="Feedback rating",
)
message_feedback_parser.add_argument(
    "content",
    location="json",
    type=str,
    help="Feedback content",
)

# App-wide feedback listing: page number and page size from the query string.
# NOTE(review): the upper bound here is 101, while message_list_parser above
# uses 100 — confirm whether the difference is intentional.
feedback_list_parser = reqparse.RequestParser()
feedback_list_parser.add_argument(
    "page",
    location="args",
    type=int,
    default=1,
    help="Page number",
)
feedback_list_parser.add_argument(
    "limit",
    location="args",
    type=int_range(1, 101),
    default=20,
    required=False,
    help="Number of feedbacks per page",
)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Api | Namespace):
    """Register and return the "Message" marshalling model on the given Api or Namespace.

    The nested feedback, agent-thought and message-file models are built on the
    same object first, because the message model references them.
    """

    def _retriever_resources(obj):
        # message_metadata is parsed as JSON when present; without it there are no resources.
        if not obj.message_metadata:
            return []
        return json.loads(obj.message_metadata).get("retriever_resources", [])

    # Nested models, built in the same order as before.
    nested_feedback = build_feedback_model(api_or_ns)
    nested_agent_thought = build_agent_thought_model(api_or_ns)
    nested_message_file = build_message_file_model(api_or_ns)

    return api_or_ns.model(
        "Message",
        {
            "id": fields.String,
            "conversation_id": fields.String,
            "parent_message_id": fields.String,
            "inputs": FilesContainedField,
            "query": fields.String,
            # The answer is read from the object's `re_sign_file_url_answer` attribute.
            "answer": fields.String(attribute="re_sign_file_url_answer"),
            "message_files": fields.List(fields.Nested(nested_message_file)),
            "feedback": fields.Nested(nested_feedback, attribute="user_feedback", allow_null=True),
            "retriever_resources": fields.Raw(attribute=_retriever_resources),
            "created_at": TimestampField,
            "agent_thoughts": fields.List(fields.Nested(nested_agent_thought)),
            "status": fields.String,
            "error": fields.String,
        },
    )
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
    """Register and return the "MessageInfiniteScrollPagination" model on the given Api or Namespace.

    The page wraps a list of "Message" models together with the page limit and
    a has_more flag.
    """
    # The nested message model has to exist before the page model can reference it.
    item_model = build_message_model(api_or_ns)

    return api_or_ns.model(
        "MessageInfiniteScrollPagination",
        {
            "limit": fields.Integer,
            "has_more": fields.Boolean,
            "data": fields.List(fields.Nested(item_model)),
        },
    )
|
||||
|
||||
|
||||
@service_api_ns.route("/messages")
class MessageListApi(Resource):
    # Lists the messages of one conversation for chat-style apps.

    @service_api_ns.expect(message_list_parser)
    @service_api_ns.doc(
        id="list_messages",
        description="List messages in a conversation",
        responses={
            200: "Messages retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Conversation or first message not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
    def get(self, app_model: App, end_user: EndUser):
        """List messages in a conversation.

        Retrieves messages with pagination support using first_id.
        """
        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        query = message_list_parser.parse_args()

        # Unknown conversation or unknown cursor message are both reported as 404.
        try:
            page = MessageService.pagination_by_first_id(
                app_model,
                end_user,
                query["conversation_id"],
                query["first_id"],
                query["limit"],
            )
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except FirstMessageNotExistsError:
            raise NotFound("First Message Not Exists.")
        return page
|
||||
|
||||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(Resource):
    # Records an end user's feedback on one message.

    @service_api_ns.expect(message_feedback_parser)
    @service_api_ns.doc(
        id="create_message_feedback",
        description="Submit feedback for a message",
        params={"message_id": "Message ID"},
        responses={
            200: "Feedback submitted successfully",
            401: "Unauthorized - invalid API token",
            404: "Message not found",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, message_id):
        """Submit feedback for a message.

        Allows users to rate messages as like/dislike and provide optional feedback content.
        """
        # The route converter yields a UUID; the service is given its string form.
        target_message_id = str(message_id)
        payload = message_feedback_parser.parse_args()

        try:
            MessageService.create_feedback(
                app_model=app_model,
                message_id=target_message_id,
                user=end_user,
                rating=payload.get("rating"),
                content=payload.get("content"),
            )
        except MessageNotExistsError:
            raise NotFound("Message Not Exists.")
        else:
            return {"result": "success"}
|
||||
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
class AppGetFeedbacksApi(Resource):
    # Returns one page of the feedback entries recorded for the app.

    @service_api_ns.expect(feedback_list_parser)
    @service_api_ns.doc(
        id="get_app_feedbacks",
        description="Get all feedbacks for the application",
        responses={
            200: "Feedbacks retrieved successfully",
            401: "Unauthorized - invalid API token",
        },
    )
    @validate_app_token
    def get(self, app_model: App):
        """Get all feedbacks for the application.

        Returns paginated list of all feedback submitted for messages in this app.
        """
        paging = feedback_list_parser.parse_args()
        page_number = paging["page"]
        page_size = paging["limit"]
        return {
            "data": MessageService.get_all_messages_feedbacks(app_model, page=page_number, limit=page_size),
        }
|
||||
|
||||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/suggested")
class MessageSuggestedApi(Resource):
    # Returns suggested follow-up questions for one message of a chat-style app.

    @service_api_ns.doc(
        id="get_suggested_questions",
        description="Get suggested follow-up questions for a message",
        params={"message_id": "Message ID"},
        responses={
            200: "Suggested questions retrieved successfully",
            400: "Suggested questions feature is disabled",
            401: "Unauthorized - invalid API token",
            404: "Message not found",
            500: "Internal server error",
        },
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
    def get(self, app_model: App, end_user: EndUser, message_id):
        """Get suggested follow-up questions for a message.

        Returns AI-generated follow-up questions based on the message content.
        """
        # The route converter yields a UUID; the service is given its string form.
        target_message_id = str(message_id)

        mode = AppMode.value_of(app_model.mode)
        if mode not in (AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT):
            raise NotChatAppError()

        # Known service errors map to 404/400; anything else is logged and surfaced as a 500.
        try:
            suggestions = MessageService.get_suggested_questions_after_answer(
                app_model=app_model,
                user=end_user,
                message_id=target_message_id,
                invoke_from=InvokeFrom.SERVICE_API,
            )
        except MessageNotExistsError:
            raise NotFound("Message Not Exists.")
        except SuggestedQuestionsAfterAnswerDisabledError:
            raise BadRequest("Suggested Questions Is Disabled.")
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()

        return {"result": "success", "data": suggestions}
|
||||
41
dify/api/controllers/service_api/app/site.py
Normal file
41
dify/api/controllers/service_api/app/site.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
|
||||
|
||||
@service_api_ns.route("/site")
class AppSiteApi(Resource):
    """Resource for app sites."""

    @service_api_ns.doc("get_app_site")
    @service_api_ns.doc(description="Get application site configuration")
    @service_api_ns.doc(
        responses={
            200: "Site configuration retrieved successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - site not found or tenant archived",
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_site_model(service_api_ns))
    def get(self, app_model: App):
        """Retrieve app site info.

        Returns the site configuration for the application including theme, icons, and text.
        """
        # First Site row whose app_id matches this application.
        site_record = db.session.query(Site).where(Site.app_id == app_model.id).first()
        if not site_record:
            raise Forbidden()

        # The tenant must be present; an archived tenant is refused like a missing site.
        assert app_model.tenant
        tenant_is_archived = app_model.tenant.status == TenantStatus.ARCHIVE
        if tenant_is_archived:
            raise Forbidden()

        return site_record
|
||||
323
dify/api/controllers/service_api/app/workflow.py
Normal file
323
dify/api/controllers/service_api/app/workflow.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import logging
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Request parsers and response field map shared by the workflow endpoints.

# JSON body accepted by the workflow "run" endpoints.
workflow_run_parser = reqparse.RequestParser()
workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
workflow_run_parser.add_argument("files", type=list, required=False, location="json")
workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")

# Query-string filters and paging accepted by the workflow log listing.
workflow_log_parser = reqparse.RequestParser()
workflow_log_parser.add_argument("keyword", type=str, location="args")
workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
workflow_log_parser.add_argument("created_at__before", type=str, location="args")
workflow_log_parser.add_argument("created_at__after", type=str, location="args")
workflow_log_parser.add_argument(
    "created_by_end_user_session_id",
    type=str,
    location="args",
    required=False,
    default=None,
)
workflow_log_parser.add_argument(
    "created_by_account",
    type=str,
    location="args",
    required=False,
    default=None,
)
workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")

# Marshalling fields used to serialise a workflow run.
workflow_run_fields = {
    "id": fields.String,
    "workflow_id": fields.String,
    "status": fields.String,
    "inputs": fields.Raw,
    "outputs": fields.Raw,
    "error": fields.String,
    "total_steps": fields.Integer,
    "total_tokens": fields.Integer,
    "created_at": TimestampField,
    "finished_at": TimestampField,
    "elapsed_time": fields.Float,
}
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Api | Namespace):
    """Build the workflow run model for the API or Namespace."""
    # Registers the module-level field map under the model name "WorkflowRun".
    model = api_or_ns.model("WorkflowRun", workflow_run_fields)
    return model
|
||||
|
||||
|
||||
# Returns the details of one workflow run that belongs to the calling application.
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
class WorkflowRunDetailApi(Resource):
    @service_api_ns.doc("get_workflow_run_detail")
    @service_api_ns.doc(description="Get workflow run details")
    @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"})
    @service_api_ns.doc(
        responses={
            200: "Workflow run details retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Workflow run not found",
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
    def get(self, app_model: App, workflow_run_id: str):
        """Get a workflow task running detail.

        Returns detailed information about a specific workflow run.
        """
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
            raise NotWorkflowAppError()

        # Use repository to get workflow run
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

        # The lookup is scoped to the token's tenant and app.
        workflow_run = workflow_run_repo.get_workflow_run_by_id(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            run_id=workflow_run_id,
        )

        # NOTE(review): assumes the repository returns None for an unknown run - confirm.
        # Without this guard a missing run is marshalled into a 200 response whose fields
        # are all null, instead of the 404 this endpoint documents.
        if workflow_run is None:
            raise NotFound("Workflow run not found.")

        return workflow_run
|
||||
|
||||
|
||||
# Executes the workflow of the calling application.
@service_api_ns.route("/workflows/run")
class WorkflowRunApi(Resource):
    @service_api_ns.expect(workflow_run_parser)
    @service_api_ns.doc("run_workflow")
    @service_api_ns.doc(description="Execute a workflow")
    @service_api_ns.doc(
        responses={
            200: "Workflow executed successfully",
            400: "Bad request - invalid parameters or workflow issues",
            401: "Unauthorized - invalid API token",
            404: "Workflow not found",
            429: "Rate limit exceeded",
            500: "Internal server error",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser):
        """Execute a workflow.

        Runs a workflow with the provided inputs and returns the results.
        Supports both blocking and streaming response modes.
        """
        if AppMode.value_of(app_model.mode) != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

        payload = workflow_run_parser.parse_args()

        # Forward an externally supplied trace id only when the request carries one.
        trace_id = get_external_trace_id(request)
        if trace_id:
            payload["external_trace_id"] = trace_id

        is_streaming = payload.get("response_mode") == "streaming"

        try:
            generated = AppGenerateService.generate(
                app_model=app_model,
                user=end_user,
                args=payload,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=is_streaming,
            )
            return helper.compact_generate_response(generated)
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError:
            # ValueError is re-raised untouched so the generic handler below does not wrap it.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
|
||||
|
||||
|
||||
# Executes a specific workflow, selected by the workflow id in the URL.
@service_api_ns.route("/workflows/<string:workflow_id>/run")
class WorkflowRunByIdApi(Resource):
    @service_api_ns.expect(workflow_run_parser)
    @service_api_ns.doc("run_workflow_by_id")
    @service_api_ns.doc(description="Execute a specific workflow by ID")
    @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
    @service_api_ns.doc(
        responses={
            200: "Workflow executed successfully",
            400: "Bad request - invalid parameters or workflow issues",
            401: "Unauthorized - invalid API token",
            404: "Workflow not found",
            429: "Rate limit exceeded",
            500: "Internal server error",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, workflow_id: str):
        """Run specific workflow by ID.

        Executes a specific workflow version identified by its ID.
        """
        if AppMode.value_of(app_model.mode) != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

        payload = workflow_run_parser.parse_args()

        # AppGenerateService receives the requested workflow through the args mapping.
        payload["workflow_id"] = workflow_id

        trace_id = get_external_trace_id(request)
        if trace_id:
            payload["external_trace_id"] = trace_id

        is_streaming = payload.get("response_mode") == "streaming"

        try:
            generated = AppGenerateService.generate(
                app_model=app_model,
                user=end_user,
                args=payload,
                invoke_from=InvokeFrom.SERVICE_API,
                streaming=is_streaming,
            )
            return helper.compact_generate_response(generated)
        except WorkflowNotFoundError as ex:
            raise NotFound(str(ex))
        except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
            # Both problems are reported as a 400 carrying the service's message.
            raise BadRequest(str(ex))
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError:
            # ValueError is re-raised untouched so the generic handler below does not wrap it.
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()
|
||||
|
||||
|
||||
# Requests that a running workflow task be stopped.
@service_api_ns.route("/workflows/tasks/<string:task_id>/stop")
class WorkflowTaskStopApi(Resource):
    @service_api_ns.doc("stop_workflow_task")
    @service_api_ns.doc(description="Stop a running workflow task")
    @service_api_ns.doc(params={"task_id": "Task ID to stop"})
    @service_api_ns.doc(
        responses={
            200: "Task stopped successfully",
            401: "Unauthorized - invalid API token",
            404: "Task not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, task_id: str):
        """Stop a running workflow task."""
        if AppMode.value_of(app_model.mode) != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

        # Both stop mechanisms are triggered for backward compatibility:
        # first the legacy stop flag (set without a user check) ...
        AppQueueManager.set_stop_flag_no_user_check(task_id)
        # ... then the graph engine command channel.
        GraphEngineManager.send_stop_command(task_id)

        return {"result": "success"}
|
||||
|
||||
|
||||
# Paginated, filterable listing of workflow execution logs for the calling application.
@service_api_ns.route("/workflows/logs")
class WorkflowAppLogApi(Resource):
    @service_api_ns.expect(workflow_log_parser)
    @service_api_ns.doc("get_workflow_logs")
    @service_api_ns.doc(description="Get workflow execution logs")
    @service_api_ns.doc(
        responses={
            200: "Logs retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_app_token
    @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
    def get(self, app_model: App):
        """Get workflow app logs.

        Returns paginated workflow execution logs with filtering options.
        """
        query = workflow_log_parser.parse_args()

        # Convert the raw string filters; an absent or empty value is passed through as-is.
        status_filter = WorkflowExecutionStatus(query.status) if query.status else None

        created_before = query.created_at__before
        if created_before:
            created_before = isoparse(created_before)

        created_after = query.created_at__after
        if created_after:
            created_after = isoparse(created_after)

        # The query runs inside a session that is closed before returning.
        log_service = WorkflowAppService()
        with Session(db.engine) as session:
            log_page = log_service.get_paginate_workflow_app_logs(
                session=session,
                app_model=app_model,
                keyword=query.keyword,
                status=status_filter,
                created_at_before=created_before,
                created_at_after=created_after,
                page=query.page,
                limit=query.limit,
                created_by_end_user_session_id=query.created_by_end_user_session_id,
                created_by_account=query.created_by_account,
            )

        return log_page
|
||||
704
dify/api/controllers/service_api/dataset/dataset.py
Normal file
704
dify/api/controllers/service_api/dataset/dataset.py
Normal file
@@ -0,0 +1,704 @@
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
# Define parsers for dataset operations

# Arguments accepted when creating a dataset.
dataset_create_parser = (
    reqparse.RequestParser()
    .add_argument(
        "name",
        nullable=False,
        required=True,
        help="type is required. Name must be between 1 to 40 characters.",
        type=_validate_name,
    )
    .add_argument(
        "description",
        type=validate_description_length,
        nullable=True,
        required=False,
        default="",
    )
    .add_argument(
        "indexing_technique",
        type=str,
        location="json",
        choices=Dataset.INDEXING_TECHNIQUE_LIST,
        help="Invalid indexing technique.",
    )
    .add_argument(
        "permission",
        type=str,
        location="json",
        choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
        help="Invalid permission.",
        required=False,
        nullable=False,
    )
    .add_argument(
        "external_knowledge_api_id",
        type=str,
        nullable=True,
        required=False,
        # The default used to be the literal string "_validate_name" (a copy-paste slip),
        # which handed a bogus id to the service whenever the field was omitted.
        # NOTE(review): confirm the service treats None as "no external knowledge API".
        default=None,
    )
    .add_argument(
        "provider",
        type=str,
        nullable=True,
        required=False,
        default="vendor",
    )
    .add_argument(
        "external_knowledge_id",
        type=str,
        nullable=True,
        required=False,
    )
    .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
    .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
    .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
|
||||
|
||||
# Arguments accepted when updating a dataset; every field is optional.
dataset_update_parser = (
    reqparse.RequestParser()
    .add_argument(
        "name",
        nullable=False,
        help="type is required. Name must be between 1 to 40 characters.",
        type=_validate_name,
    )
    # store_missing=False keeps an omitted description out of the parsed args.
    .add_argument("description", location="json", store_missing=False, type=validate_description_length)
    .add_argument(
        "indexing_technique",
        type=str,
        location="json",
        choices=Dataset.INDEXING_TECHNIQUE_LIST,
        nullable=True,
        help="Invalid indexing technique.",
    )
    .add_argument(
        "permission",
        type=str,
        location="json",
        choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
        help="Invalid permission.",
    )
    .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
    .add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
    .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
    # Help text corrected: it previously read "Invalid parent user list.".
    .add_argument("partial_member_list", type=list, location="json", help="Invalid partial member list.")
    .add_argument(
        "external_retrieval_model",
        type=dict,
        required=False,
        nullable=True,
        location="json",
        help="Invalid external retrieval model.",
    )
    .add_argument(
        "external_knowledge_id",
        type=str,
        required=False,
        nullable=True,
        location="json",
        help="Invalid external knowledge id.",
    )
    .add_argument(
        "external_knowledge_api_id",
        type=str,
        required=False,
        nullable=True,
        location="json",
        help="Invalid external knowledge api id.",
    )
)
|
||||
|
||||
def _validate_tag_name(value):
    """Return *value* when it is 1 to 50 characters long.

    Replaces the ``(_ for _ in ()).throw(...)`` lambda that was duplicated in
    both tag parsers; the accepted values and the error are unchanged.

    Raises:
        ValueError: If *value* is empty/falsy or longer than 50 characters.
    """
    if value and 1 <= len(value) <= 50:
        return value
    raise ValueError("Name must be between 1 to 50 characters.")


# Arguments accepted when creating a tag.
tag_create_parser = reqparse.RequestParser().add_argument(
    "name",
    nullable=False,
    required=True,
    help="Name must be between 1 to 50 characters.",
    type=_validate_tag_name,
)

# Arguments accepted when renaming a tag.
tag_update_parser = (
    reqparse.RequestParser()
    .add_argument(
        "name",
        nullable=False,
        required=True,
        help="Name must be between 1 to 50 characters.",
        type=_validate_tag_name,
    )
    .add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
)
|
||||
|
||||
# Arguments accepted when deleting a tag.
tag_delete_parser = reqparse.RequestParser()
tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)

# Arguments accepted when binding tags to a dataset.
tag_binding_parser = reqparse.RequestParser()
tag_binding_parser.add_argument(
    "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
tag_binding_parser.add_argument(
    "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)

# Arguments accepted when removing one tag binding.
tag_unbinding_parser = reqparse.RequestParser()
tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets")
class DatasetListApi(DatasetApiResource):
    """Resource for datasets."""

    @service_api_ns.doc("list_datasets")
    @service_api_ns.doc(description="List all datasets")
    @service_api_ns.doc(
        responses={
            200: "Datasets retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def get(self, tenant_id):
        """Resource for getting datasets."""
        query = request.args
        page = query.get("page", default=1, type=int)
        limit = query.get("limit", default=20, type=int)
        search = query.get("keyword", default=None, type=str)
        tag_ids = query.getlist("tag_ids")
        include_all = query.get("include_all", default="false").lower() == "true"

        datasets, total = DatasetService.get_datasets(
            page, limit, tenant_id, current_user, search, tag_ids, include_all
        )

        # Collect the active text-embedding models as "model:provider" keys so each
        # dataset can be flagged with whether its embedding model is available.
        provider_manager = ProviderManager()
        assert isinstance(current_user, Account)
        current_tenant_id = current_user.current_tenant_id
        assert current_tenant_id is not None
        configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
        active_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
        available_keys = {f"{model.model}:{model.provider.provider}" for model in active_models}

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            uses_embedding = item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]
            if not uses_embedding:
                item["embedding_available"] = True
                continue
            # Rewrite the provider through ModelProviderID before building the lookup key.
            item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
            item_key = f"{item['embedding_model']}:{item['embedding_model_provider']}"
            item["embedding_available"] = item_key in available_keys

        response = {
            "data": data,
            "has_more": len(datasets) == limit,
            "limit": limit,
            "total": total,
            "page": page,
        }
        return response, 200

    @service_api_ns.expect(dataset_create_parser)
    @service_api_ns.doc("create_dataset")
    @service_api_ns.doc(description="Create a new dataset")
    @service_api_ns.doc(
        responses={
            200: "Dataset created successfully",
            401: "Unauthorized - invalid API token",
            400: "Bad request - invalid parameters",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id):
        """Resource for creating datasets."""
        args = dataset_create_parser.parse_args()

        # Check the embedding model only when both provider and model are supplied.
        embedding_model_provider = args.get("embedding_model_provider")
        embedding_model = args.get("embedding_model")
        if embedding_model_provider and embedding_model:
            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

        # Check the reranking model only when a provider name is configured.
        retrieval_model = args.get("retrieval_model")
        reranking_model = retrieval_model.get("reranking_model") if retrieval_model else None
        if reranking_model and reranking_model.get("reranking_provider_name"):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                reranking_model.get("reranking_provider_name"),
                reranking_model.get("reranking_model_name"),
            )

        try:
            assert isinstance(current_user, Account)
            raw_retrieval_model = args["retrieval_model"]
            dataset = DatasetService.create_empty_dataset(
                tenant_id=tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=args["permission"],
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
                embedding_model_provider=args["embedding_model_provider"],
                embedding_model_name=args["embedding_model"],
                retrieval_model=(
                    RetrievalModel.model_validate(raw_retrieval_model) if raw_retrieval_model is not None else None
                ),
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            # Translate the service-level error into this controller's error type.
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>")
|
||||
class DatasetApi(DatasetApiResource):
|
||||
"""Resource for dataset."""
|
||||
|
||||
@service_api_ns.doc("get_dataset")
|
||||
@service_api_ns.doc(description="Get a specific dataset by ID")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Dataset retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
def get(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
configurations = provider_manager.get_configurations(tenant_id=cid)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
model_names = []
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
else:
|
||||
data["embedding_available"] = False
|
||||
else:
|
||||
data["embedding_available"] = True
|
||||
|
||||
# force update search method to keyword_search if indexing_technique is economic
|
||||
retrieval_model_dict = data.get("retrieval_model_dict")
|
||||
if retrieval_model_dict:
|
||||
retrieval_model_dict["search_method"] = "keyword_search"
|
||||
|
||||
if data.get("permission") == "partial_members":
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
return data, 200
|
||||
|
||||
@service_api_ns.expect(dataset_update_parser)
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@service_api_ns.doc(description="Update an existing dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Dataset updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
args = dataset_update_parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
# check embedding model setting
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||
)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif (
|
||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||
):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({"partial_member_list": partial_member_list})
|
||||
|
||||
return result_data, 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset")
|
||||
@service_api_ns.doc(description="Delete a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Dataset deleted successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Dataset not found",
|
||||
409: "Conflict - dataset is in use",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, _, dataset_id):
    """
    Delete a dataset given its ID.

    Args:
        _: unused positional argument (ignored).
        dataset_id (UUID): The ID of the dataset to be deleted.

    Returns:
        tuple: ``({"result": "success"}, 204)``. The status code is returned
        as the second tuple element so that it is used as the HTTP status;
        returning a bare ``204`` would be serialized as the response *body*
        with the default 200 status.

    Raises:
        NotFound: If ``DatasetService.delete_dataset`` reports that nothing
            was deleted.
        DatasetInUseError: If the service layer raises its own
            ``DatasetInUseError`` for this dataset.
    """
    dataset_id_str = str(dataset_id)

    try:
        if not DatasetService.delete_dataset(dataset_id_str, current_user):
            raise NotFound("Dataset not found.")
        # Drop the per-member permission rows that belonged to the dataset.
        DatasetPermissionService.clear_partial_member_list(dataset_id_str)
    except services.errors.dataset.DatasetInUseError as e:
        # Translate the service-level error into the API-level error.
        raise DatasetInUseError() from e

    return {"result": "success"}, 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>")
class DocumentStatusApi(DatasetApiResource):
    """Resource for batch document status operations."""

    @service_api_ns.doc("update_document_status")
    @service_api_ns.doc(description="Batch update document status")
    @service_api_ns.doc(
        params={
            "dataset_id": "Dataset ID",
            "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Document status updated successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Dataset not found",
            400: "Bad request - invalid action",
        }
    )
    def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
        """
        Batch update document status.

        The JSON request body may carry a ``document_ids`` list; when the key
        is missing, or the body is not a JSON object, an empty list is used.

        Args:
            tenant_id: tenant id
            dataset_id: dataset id
            action: action to perform (Literal["enable", "disable", "archive", "un_archive"])

        Returns:
            dict: A dictionary with a key 'result' and a value 'success'
            int: HTTP status code 200 indicating that the operation was successful.

        Raises:
            NotFound: If the dataset with the given ID does not exist.
            Forbidden: If the user does not have permission.
            InvalidActionError: If the action is invalid or cannot be performed.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)

        if dataset is None:
            raise NotFound("Dataset not found.")

        # Check user's permission
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        # Check dataset model setting
        DatasetService.check_dataset_model_setting(dataset)

        # Get document IDs from request body. A JSON ``null`` or non-object
        # body would otherwise fail with AttributeError on ``.get``.
        payload = request.get_json()
        document_ids = payload.get("document_ids", []) if isinstance(payload, dict) else []

        try:
            DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
        except (services.errors.document.DocumentIndexingError, ValueError) as e:
            # Both failure kinds are reported to the client the same way.
            raise InvalidActionError(str(e)) from e

        return {"result": "success"}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags")
class DatasetTagsApi(DatasetApiResource):
    """Resource for listing, creating, updating and deleting knowledge tags."""

    # NOTE(review): the route has no ``dataset_id`` segment, yet every handler
    # accepts two positional parameters after ``self``. They are presumably
    # filled by the token-validation decorators — confirm against
    # controllers/service_api/wraps.py before renaming or removing them.

    @service_api_ns.doc("list_dataset_tags")
    @service_api_ns.doc(description="Get all knowledge type tags")
    @service_api_ns.doc(
        responses={
            200: "Tags retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_dataset_token
    @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
    def get(self, _, dataset_id):
        """Get all knowledge type tags."""
        assert isinstance(current_user, Account)
        cid = current_user.current_tenant_id
        assert cid is not None
        tags = TagService.get_tags("knowledge", cid)

        return tags, 200

    @service_api_ns.expect(tag_create_parser)
    @service_api_ns.doc("create_dataset_tag")
    @service_api_ns.doc(description="Add a knowledge type tag")
    @service_api_ns.doc(
        responses={
            200: "Tag created successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
    @validate_dataset_token
    def post(self, _, dataset_id):
        """Add a knowledge type tag."""
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        args = tag_create_parser.parse_args()
        args["type"] = "knowledge"
        tag = TagService.save_tags(args)

        # A freshly created tag is reported with zero bindings.
        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
        return response, 200

    @service_api_ns.expect(tag_update_parser)
    @service_api_ns.doc("update_dataset_tag")
    @service_api_ns.doc(description="Update a knowledge type tag")
    @service_api_ns.doc(
        responses={
            200: "Tag updated successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
    @validate_dataset_token
    def patch(self, _, dataset_id):
        """Update a knowledge type tag and return it with its binding count."""
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        args = tag_update_parser.parse_args()
        args["type"] = "knowledge"
        tag_id = args["tag_id"]
        tag = TagService.update_tags(args, tag_id)

        binding_count = TagService.get_tag_binding_count(tag_id)

        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}

        return response, 200

    @service_api_ns.expect(tag_delete_parser)
    @service_api_ns.doc("delete_dataset_tag")
    @service_api_ns.doc(description="Delete a knowledge type tag")
    @service_api_ns.doc(
        responses={
            204: "Tag deleted successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @validate_dataset_token
    @edit_permission_required
    def delete(self, _, dataset_id):
        """Delete a knowledge type tag."""
        args = tag_delete_parser.parse_args()
        TagService.delete_tag(args["tag_id"])

        # Return (body, status): a bare ``204`` would be sent as the response
        # body with the default 200 status instead of the documented 204.
        return "", 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
    """Resource for binding knowledge tags to a dataset."""

    @service_api_ns.expect(tag_binding_parser)
    @service_api_ns.doc("bind_dataset_tags")
    @service_api_ns.doc(description="Bind tags to a dataset")
    @service_api_ns.doc(
        responses={
            204: "Tags bound successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @validate_dataset_token
    def post(self, _, dataset_id):
        """Bind the tags named in the request body as ``knowledge`` tags.

        Raises:
            Forbidden: If the current account has neither edit permission nor
                the dataset-editor flag.
        """
        # Only accounts with edit permission or the dataset-editor flag may bind tags.
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        args = tag_binding_parser.parse_args()
        args["type"] = "knowledge"
        TagService.save_tag_binding(args)

        # Return (body, status): a bare ``204`` would be sent as the response
        # body with the default 200 status instead of the documented 204.
        return "", 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
    """Resource for removing a knowledge tag binding from a dataset."""

    @service_api_ns.expect(tag_unbinding_parser)
    @service_api_ns.doc("unbind_dataset_tag")
    @service_api_ns.doc(description="Unbind a tag from a dataset")
    @service_api_ns.doc(
        responses={
            204: "Tag unbound successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @validate_dataset_token
    def post(self, _, dataset_id):
        """Remove the tag binding described in the request body.

        Raises:
            Forbidden: If the current account has neither edit permission nor
                the dataset-editor flag.
        """
        # Only accounts with edit permission or the dataset-editor flag may unbind tags.
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        args = tag_unbinding_parser.parse_args()
        args["type"] = "knowledge"
        TagService.delete_tag_binding(args)

        # Return (body, status): a bare ``204`` would be sent as the response
        # body with the default 200 status instead of the documented 204.
        return "", 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
class DatasetTagsBindingStatusApi(DatasetApiResource):
    """Resource reporting which knowledge tags are bound to one dataset."""

    @service_api_ns.doc("get_dataset_tags_binding_status")
    @service_api_ns.doc(description="Get tags bound to a specific dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Tags retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_dataset_token
    def get(self, _, *args, **kwargs):
        """Return the knowledge tags bound to the dataset named in the URL.

        The dataset id is read from the ``dataset_id`` keyword argument.

        Returns:
            tuple: ``({"data": [{"id", "name"}, ...], "total": <count>}, 200)``.
        """
        target_id = kwargs.get("dataset_id")

        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None

        bound_tags = TagService.get_tags_by_target_id(
            "knowledge",
            current_user.current_tenant_id,
            str(target_id),
        )

        # Only id and name are exposed for each bound tag.
        summaries = []
        for bound_tag in bound_tags:
            summaries.append({"id": bound_tag.id, "name": bound_tag.name})

        return {"data": summaries, "total": len(bound_tags)}, 200
|
||||
705
dify/api/controllers/service_api/dataset/document.py
Normal file
705
dify/api/controllers/service_api/dataset/document.py
Normal file
@@ -0,0 +1,705 @@
|
||||
import json
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentIndexingError,
|
||||
InvalidMetadataError,
|
||||
)
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
|
||||
# Request parser for the "create document by text" endpoint. Every argument is
# read from the JSON body only.
document_text_create_parser = reqparse.RequestParser()
# Mandatory document content.
document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json")
document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json")
# Optional processing configuration.
document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json")
document_text_create_parser.add_argument(
    "doc_form", type=str, default="text_model", required=False, nullable=False, location="json"
)
document_text_create_parser.add_argument(
    "doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
document_text_create_parser.add_argument(
    "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
# Optional retrieval / embedding overrides.
document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
document_text_create_parser.add_argument(
    "embedding_model_provider", type=str, required=False, nullable=True, location="json"
)

# ``ref_template`` handed to pydantic's ``model_json_schema`` when the models
# below are registered, so that ``$ref`` values point at ``#/definitions/...``.
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class DocumentTextUpdate(BaseModel):
    """Request body accepted by the "update document by text" endpoint."""

    # Both are optional, but ``text`` may only be sent together with ``name``.
    name: str | None = None
    text: str | None = None
    process_rule: ProcessRule | None = None
    doc_form: str = "text_model"
    doc_language: str = "English"
    retrieval_model: RetrievalModel | None = None

    @model_validator(mode="after")
    def check_text_and_name(self) -> Self:
        """Reject payloads that provide ``text`` without a ``name``."""
        text_without_name = self.name is None and self.text is not None
        if text_without_name:
            raise ValueError("name is required when text is provided")
        return self
|
||||
|
||||
|
||||
# Register the pydantic models on the namespace so they can be referenced from
# the generated API documentation (e.g. via ``service_api_ns.models[...]``).
for m in (ProcessRule, RetrievalModel, DocumentTextUpdate):
    service_api_ns.schema_model(
        m.__name__,
        m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
    )  # type: ignore
|
||||
|
||||
|
||||
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/document/create_by_text",
    "/datasets/<uuid:dataset_id>/document/create-by-text",
)
class DocumentAddByTextApi(DatasetApiResource):
    """Resource for creating a document from raw text."""

    @service_api_ns.expect(document_text_create_parser)
    @service_api_ns.doc("create_document_by_text")
    @service_api_ns.doc(description="Create a new document by providing text content")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Document created successfully",
            401: "Unauthorized - invalid API token",
            400: "Bad request - invalid parameters",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_resource_check("documents", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id):
        """Create document by text.

        The text is first stored through ``FileService.upload_text`` and the
        resulting upload is then used as the document's data source.

        Args:
            tenant_id: tenant that must own the dataset.
            dataset_id: id of the dataset the document is added to.

        Returns:
            tuple: ``({"document": <marshalled document>, "batch": <batch>}, 200)``.

        Raises:
            ValueError: If the dataset does not exist, no indexing technique is
                available, ``text``/``name`` are null, or there is no current user.
            ProviderNotInitializeError: If saving raises ``ProviderTokenNotInitError``.
        """
        args = document_text_create_parser.parse_args()

        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

        if not dataset:
            raise ValueError("Dataset does not exist.")

        if not dataset.indexing_technique and not args["indexing_technique"]:
            raise ValueError("indexing_technique is required.")

        text = args.get("text")
        name = args.get("name")
        if text is None or name is None:
            raise ValueError("Both 'text' and 'name' must be non-null values.")

        embedding_model_provider = args.get("embedding_model_provider")
        embedding_model = args.get("embedding_model")
        if embedding_model_provider and embedding_model:
            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

        # Validate the reranking model only when a provider name was supplied.
        retrieval_model = args.get("retrieval_model")
        reranking_model = retrieval_model.get("reranking_model") if retrieval_model else None
        if reranking_model and reranking_model.get("reranking_provider_name"):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                reranking_model.get("reranking_provider_name"),
                reranking_model.get("reranking_model_name"),
            )

        if not current_user:
            raise ValueError("current_user is required")

        upload_file = FileService(db.engine).upload_text(
            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
        )
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
        args["data_source"] = data_source
        knowledge_config = KnowledgeConfig.model_validate(args)
        # validate args
        DocumentService.document_create_args_validate(knowledge_config)

        # The request parser stores every declared argument (missing ones as
        # None), so a key-membership test can never detect an omitted
        # ``process_rule``. Compare the value instead so the dataset's latest
        # rule is used as the fallback when the caller sent none.
        process_rule_omitted = args.get("process_rule") is None

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=current_user,
                dataset_process_rule=dataset.latest_process_rule if process_rule_omitted else None,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        document = documents[0]

        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
        return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
)
class DocumentUpdateByTextApi(DatasetApiResource):
    """Resource for update documents."""

    @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
    @service_api_ns.doc("update_document_by_text")
    @service_api_ns.doc(description="Update an existing document by providing text content")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Document updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Document not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by text.

        Only fields explicitly present in the payload are forwarded. When
        ``text`` is provided it is uploaded through ``FileService.upload_text``
        and becomes the document's new data source.

        Returns:
            tuple: ``({"document": <marshalled document>, "batch": <batch>}, 200)``.

        Raises:
            ValueError: If the dataset does not exist, or text is supplied
                while there is no current user.
            ProviderNotInitializeError: If saving raises ``ProviderTokenNotInitError``.
        """
        args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()

        if not dataset:
            raise ValueError("Dataset does not exist.")

        # Validate the reranking model only when a provider name was supplied.
        retrieval_model = args.get("retrieval_model")
        reranking_model = retrieval_model.get("reranking_model") if retrieval_model else None
        if reranking_model and reranking_model.get("reranking_provider_name"):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                reranking_model.get("reranking_provider_name"),
                reranking_model.get("reranking_model_name"),
            )

        # indexing_technique is already set in dataset since this is an update
        args["indexing_technique"] = dataset.indexing_technique

        if args.get("text"):
            text = args.get("text")
            name = args.get("name")
            if not current_user:
                raise ValueError("current_user is required")
            upload_file = FileService(db.engine).upload_text(
                text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
            )
            data_source = {
                "type": "upload_file",
                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
            }
            args["data_source"] = data_source
        # validate args
        args["original_document_id"] = str(document_id)
        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        # ``exclude_unset`` keeps an explicit ``"process_rule": null``, so test
        # the value rather than key membership; otherwise a null rule would
        # disable the fallback to the dataset's latest process rule.
        process_rule_omitted = args.get("process_rule") is None

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=current_user,
                dataset_process_rule=dataset.latest_process_rule if process_rule_omitted else None,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        document = documents[0]

        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
        return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/document/create_by_file",
    "/datasets/<uuid:dataset_id>/document/create-by-file",
)
class DocumentAddByFileApi(DatasetApiResource):
    """Resource for creating a document from an uploaded file."""

    @service_api_ns.doc("create_document_by_file")
    @service_api_ns.doc(description="Create a new document by uploading a file")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Document created successfully",
            401: "Unauthorized - invalid API token",
            400: "Bad request - invalid file or parameters",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_resource_check("documents", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id):
        """Create document by upload file.

        Settings are read from the optional ``data`` form field (a JSON
        string); the file itself must be sent as the single ``file`` part.

        Returns:
            tuple: ``({"document": <marshalled document>, "batch": <batch>}, 200)``.

        Raises:
            ValueError: If the dataset is missing or external, no indexing
                technique or process rule is available, or there is no current user.
            NoFileUploadedError: If the request has no ``file`` part.
            TooManyFilesError: If more than one file was sent.
            FilenameNotExistsError: If the uploaded file has no filename.
            FileTooLargeError / UnsupportedFileTypeError: If the file service
                rejects the upload.
            ProviderNotInitializeError: If saving raises ``ProviderTokenNotInitError``.
        """
        args = {}
        if "data" in request.form:
            args = json.loads(request.form["data"])
        if "doc_form" not in args:
            args["doc_form"] = "text_model"
        if "doc_language" not in args:
            args["doc_language"] = "English"

        # get dataset info
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

        if not dataset:
            raise ValueError("Dataset does not exist.")

        if dataset.provider == "external":
            raise ValueError("External datasets are not supported.")

        indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
        if not indexing_technique:
            raise ValueError("indexing_technique is required.")
        args["indexing_technique"] = indexing_technique

        # Same rule as the create-by-text endpoint: validate only when both
        # values are supplied. Indexing ``args["embedding_model"]`` directly
        # raised KeyError when only the provider was sent.
        embedding_model_provider = args.get("embedding_model_provider")
        embedding_model = args.get("embedding_model")
        if embedding_model_provider and embedding_model:
            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

        # ``retrieval_model`` may be present but null; ``.get`` tolerates that.
        retrieval_model = args.get("retrieval_model")
        reranking_model = retrieval_model.get("reranking_model") if retrieval_model else None
        if reranking_model and reranking_model.get("reranking_provider_name"):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                reranking_model.get("reranking_provider_name"),
                reranking_model.get("reranking_model_name"),
            )

        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        # save file info
        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError

        if not current_user:
            raise ValueError("current_user is required")

        # Translate service-level upload failures into API errors, matching
        # the update-by-file endpoint.
        try:
            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
                source="datasets",
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description) from file_too_large_error
        except services.errors.file.UnsupportedFileTypeError as unsupported_error:
            raise UnsupportedFileTypeError() from unsupported_error

        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
        args["data_source"] = data_source
        # validate args
        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
        if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
            raise ValueError("process_rule is required.")

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=dataset.created_by_account,
                dataset_process_rule=dataset_process_rule,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        document = documents[0]
        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
        return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
class DocumentUpdateByFileApi(DatasetApiResource):
    """Resource that replaces an existing document's content with an uploaded file."""

    @service_api_ns.doc("update_document_by_file")
    @service_api_ns.doc(description="Update an existing document by uploading a file")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Document updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Document not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id, document_id):
        """Update document by upload file.

        Settings come from the optional ``data`` form field (a JSON string).
        The ``file`` part is optional; when present it is uploaded and used as
        the new data source.

        Returns:
            tuple: ``({"document": <marshalled document>, "batch": <document batch>}, 200)``.
        """
        args = {}
        if "data" in request.form:
            args = json.loads(request.form["data"])
        # Fill in defaults the caller left out.
        if "doc_form" not in args:
            args["doc_form"] = "text_model"
        if "doc_language" not in args:
            args["doc_language"] = "English"

        # Resolve the dataset within the caller's tenant.
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset_lookup = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
        dataset = dataset_lookup.first()

        if not dataset:
            raise ValueError("Dataset does not exist.")
        if dataset.provider == "external":
            raise ValueError("External datasets are not supported.")

        # An update always reuses the technique already stored on the dataset.
        args["indexing_technique"] = dataset.indexing_technique

        if "file" in request.files:
            uploaded = request.files["file"]

            if len(request.files) > 1:
                raise TooManyFilesError()
            if not uploaded.filename:
                raise FilenameNotExistsError
            if not current_user:
                raise ValueError("current_user is required")

            try:
                stored_file = FileService(db.engine).upload_file(
                    filename=uploaded.filename,
                    content=uploaded.read(),
                    mimetype=uploaded.mimetype,
                    user=current_user,
                    source="datasets",
                )
            except services.errors.file.FileTooLargeError as file_too_large_error:
                raise FileTooLargeError(file_too_large_error.description)
            except services.errors.file.UnsupportedFileTypeError:
                raise UnsupportedFileTypeError()

            args["data_source"] = {
                "type": "upload_file",
                "info_list": {
                    "data_source_type": "upload_file",
                    "file_info_list": {"file_ids": [stored_file.id]},
                },
            }

        # Mark this as an update of the addressed document, then validate.
        args["original_document_id"] = str(document_id)
        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        # Fall back to the dataset's latest rule only when none was sent.
        if "process_rule" not in args:
            fallback_rule = dataset.latest_process_rule
        else:
            fallback_rule = None

        try:
            saved_documents, _ = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=dataset.created_by_account,
                dataset_process_rule=fallback_rule,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        updated = saved_documents[0]
        return {"document": marshal(updated, document_fields), "batch": updated.batch}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
class DocumentListApi(DatasetApiResource):
    """Resource for listing the documents of a dataset."""

    @service_api_ns.doc("list_documents")
    @service_api_ns.doc(description="List all documents in a dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Documents retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
        }
    )
    def get(self, tenant_id, dataset_id):
        """Return one page of the dataset's documents, newest first.

        Query parameters: ``page`` (default 1), ``limit`` (default 20; the
        page size is capped at 100), ``keyword`` (substring match on the
        document name) and ``status`` (display-status filter).

        Returns:
            dict: ``data``, ``has_more``, ``limit``, ``total`` and ``page``.

        Raises:
            NotFound: If the dataset does not exist for this tenant.
        """
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        search = request.args.get("keyword", default=None, type=str)
        status = request.args.get("status", default=None, type=str)
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)

        if status:
            query = DocumentService.apply_display_status_filter(query, status)

        if search:
            search = f"%{search}%"
            query = query.where(Document.name.like(search))

        query = query.order_by(desc(Document.created_at), desc(Document.position))

        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
        documents = paginated_documents.items

        response = {
            "data": marshal(documents, document_fields),
            # Ask the pagination object whether another page exists. Comparing
            # the page length with ``limit`` reports a phantom next page when
            # the total is an exact multiple of ``limit`` and never reports one
            # when ``limit`` exceeds the 100-item cap.
            "has_more": paginated_documents.has_next,
            "limit": limit,
            "total": paginated_documents.total,
            "page": page,
        }

        return response
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
class DocumentIndexingStatusApi(DatasetApiResource):
    """Resource reporting indexing progress for every document of a batch."""

    @service_api_ns.doc("get_document_indexing_status")
    @service_api_ns.doc(description="Get indexing status for documents in a batch")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"})
    @service_api_ns.doc(
        responses={
            200: "Indexing status retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset or documents not found",
        }
    )
    def get(self, tenant_id, dataset_id, batch):
        """Return per-document indexing status and segment counts for a batch.

        Returns:
            dict: ``{"data": [<marshalled status>, ...]}``.

        Raises:
            NotFound: If the dataset does not exist for this tenant, or the
                batch contains no documents.
        """
        dataset_id = str(dataset_id)
        batch = str(batch)
        tenant_id = str(tenant_id)

        # The dataset must belong to the calling tenant.
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        batch_documents = DocumentService.get_batch_documents(dataset_id, batch)
        if not batch_documents:
            raise NotFound("Documents not found.")

        status_rows = []
        for doc in batch_documents:
            # Segments marked "re_segment" are excluded from both counts.
            segment_scope = (
                DocumentSegment.document_id == str(doc.id),
                DocumentSegment.status != "re_segment",
            )
            done_count = (
                db.session.query(DocumentSegment)
                .where(DocumentSegment.completed_at.isnot(None), *segment_scope)
                .count()
            )
            all_count = db.session.query(DocumentSegment).where(*segment_scope).count()

            # A paused document reports "paused" instead of its stored status.
            if doc.is_paused:
                shown_status = "paused"
            else:
                shown_status = doc.indexing_status

            snapshot = {
                "id": doc.id,
                "indexing_status": shown_status,
                "processing_started_at": doc.processing_started_at,
                "parsing_completed_at": doc.parsing_completed_at,
                "cleaning_completed_at": doc.cleaning_completed_at,
                "splitting_completed_at": doc.splitting_completed_at,
                "completed_at": doc.completed_at,
                "paused_at": doc.paused_at,
                "error": doc.error,
                "stopped_at": doc.stopped_at,
                "completed_segments": done_count,
                "total_segments": all_count,
            }
            status_rows.append(marshal(snapshot, document_status_fields))

        return {"data": status_rows}
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
class DocumentApi(DatasetApiResource):
    """Retrieve or delete a single document of a dataset."""

    # Accepted values of the ``metadata`` query argument of GET.
    METADATA_CHOICES = {"all", "only", "without"}

    @service_api_ns.doc("get_document")
    @service_api_ns.doc(description="Get a specific document by ID")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Document retrieved successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Document not found",
        }
    )
    def get(self, tenant_id, dataset_id, document_id):
        """Return the details of one document.

        The ``metadata`` query argument selects the payload: ``"only"`` returns
        just the ID and metadata fields, ``"without"`` returns the document
        details without ``doc_type``/``doc_metadata``, and ``"all"`` (default)
        returns the details including them.

        Raises:
            NotFound: the document does not exist in the dataset.
            Forbidden: the document belongs to a different tenant.
            InvalidMetadataError: ``metadata`` is not one of METADATA_CHOICES.
        """
        dataset_id = str(dataset_id)
        document_id = str(document_id)

        dataset = self.get_dataset(dataset_id, tenant_id)

        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if document.tenant_id != str(tenant_id):
            raise Forbidden("No permission.")

        metadata = request.args.get("metadata", "all")
        if metadata not in self.METADATA_CHOICES:
            raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

        if metadata == "only":
            return {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}

        dataset_process_rules = DatasetService.get_process_rules(dataset_id)
        document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
        data_source_info = document.data_source_detail_dict
        # "all" and "without" share every field except the two metadata keys,
        # so build one payload and splice the metadata in only for "all".
        metadata_fields = (
            {"doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
            if metadata == "all"
            else {}
        )
        return {
            "id": document.id,
            "position": document.position,
            "data_source_type": document.data_source_type,
            "data_source_info": data_source_info,
            "dataset_process_rule_id": document.dataset_process_rule_id,
            "dataset_process_rule": dataset_process_rules,
            "document_process_rule": document_process_rules,
            "name": document.name,
            "created_from": document.created_from,
            "created_by": document.created_by,
            "created_at": int(document.created_at.timestamp()),
            "tokens": document.tokens,
            "indexing_status": document.indexing_status,
            "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
            "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
            "indexing_latency": document.indexing_latency,
            "error": document.error,
            "enabled": document.enabled,
            "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
            "disabled_by": document.disabled_by,
            "archived": document.archived,
            **metadata_fields,
            "segment_count": document.segment_count,
            "average_segment_length": document.average_segment_length,
            "hit_count": document.hit_count,
            "display_status": document.display_status,
            "doc_form": document.doc_form,
            "doc_language": document.doc_language,
        }

    @service_api_ns.doc("delete_document")
    @service_api_ns.doc(description="Delete a document")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            204: "Document deleted successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - document is archived",
            404: "Document not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, tenant_id, dataset_id, document_id):
        """Delete a document.

        Raises:
            NotFound: the dataset or the document does not exist.
            ArchivedDocumentImmutableError: the document is archived.
            DocumentIndexingError: the document is still being indexed.
        """
        document_id = str(document_id)
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)

        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            # Previously a ValueError, which surfaced as an HTTP 500 rather
            # than the documented 404.
            raise NotFound("Dataset does not exist.")

        document = DocumentService.get_document(dataset.id, document_id)
        if document is None:
            raise NotFound("Document Not Exists.")

        # Archived documents are immutable (403).
        if DocumentService.check_archived(document):
            raise ArchivedDocumentImmutableError()

        try:
            DocumentService.delete_document(document)
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError("Cannot delete document during indexing.")

        # NOTE(review): a bare int is returned as the response *body*, so the
        # status code is presumably the framework default rather than 204 -
        # confirm whether `return "", 204` is intended.
        return 204
|
||||
55
dify/api/controllers/service_api/dataset/error.py
Normal file
55
dify/api/controllers/service_api/dataset/error.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class DatasetNotInitializedError(BaseHTTPException):
    """HTTP 400: the dataset is still being initialized or indexed."""

    code = 400
    error_code = "dataset_not_initialized"
    description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
|
||||
|
||||
class ArchivedDocumentImmutableError(BaseHTTPException):
    """HTTP 403: an archived document cannot be edited."""

    code = 403
    error_code = "archived_document_immutable"
    description = "The archived document is not editable."
|
||||
|
||||
|
||||
class DatasetNameDuplicateError(BaseHTTPException):
    """HTTP 409: a dataset with the requested name already exists."""

    code = 409
    error_code = "dataset_name_duplicate"
    description = "The dataset name already exists. Please modify your dataset name."
|
||||
|
||||
|
||||
class InvalidActionError(BaseHTTPException):
    """HTTP 400: the requested action is not valid."""

    code = 400
    error_code = "invalid_action"
    description = "Invalid action."
|
||||
|
||||
|
||||
class DocumentAlreadyFinishedError(BaseHTTPException):
    """HTTP 400: the document has already been processed."""

    code = 400
    error_code = "document_already_finished"
    description = "The document has been processed. Please refresh the page or go to the document details."
|
||||
|
||||
|
||||
class DocumentIndexingError(BaseHTTPException):
    """HTTP 400: the document is being processed and cannot be edited."""

    code = 400
    error_code = "document_indexing"
    description = "The document is being processed and cannot be edited."
|
||||
|
||||
|
||||
class InvalidMetadataError(BaseHTTPException):
    """HTTP 400: the supplied metadata content is incorrect."""

    code = 400
    error_code = "invalid_metadata"
    description = "The metadata content is incorrect. Please check and verify."
|
||||
|
||||
|
||||
class DatasetInUseError(BaseHTTPException):
    """HTTP 409: the dataset is still used by apps and cannot be deleted."""

    code = 409
    error_code = "dataset_in_use"
    description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||
|
||||
|
||||
class PipelineRunError(BaseHTTPException):
    """HTTP 500: running the pipeline failed."""

    code = 500
    error_code = "pipeline_run_error"
    description = "An error occurred while running the pipeline."
|
||||
30
dify/api/controllers/service_api/dataset/hit_testing.py
Normal file
30
dify/api/controllers/service_api/dataset/hit_testing.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
    """Hit-testing endpoint, registered under both ``/hit-testing`` and ``/retrieve``."""

    @service_api_ns.doc("dataset_hit_testing")
    @service_api_ns.doc(description="Perform hit testing on a dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Hit testing results",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id):
        """Run a hit test against the dataset and return its result.

        The dataset lookup, argument parsing, argument check and the test
        itself are all delegated to the inherited ``DatasetsHitTestingBase``
        helpers.
        """
        dataset = self.get_and_validate_dataset(str(dataset_id))
        payload = self.parse_args()
        self.hit_testing_args_check(payload)
        return self.perform_hit_testing(dataset, payload)
|
||||
203
dify/api/controllers/service_api/dataset/metadata.py
Normal file
203
dify/api/controllers/service_api/dataset/metadata.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
# Request parsers shared by the metadata endpoints below.

# Body of "create metadata": both fields are mandatory JSON strings.
metadata_create_parser = reqparse.RequestParser()
metadata_create_parser.add_argument(
    "type", type=str, required=True, nullable=False, location="json", help="Metadata type"
)
metadata_create_parser.add_argument(
    "name", type=str, required=True, nullable=False, location="json", help="Metadata name"
)

# Body of "rename metadata".
metadata_update_parser = reqparse.RequestParser()
metadata_update_parser.add_argument(
    "name", type=str, required=True, nullable=False, location="json", help="New metadata name"
)

# Body of "update documents metadata".
document_metadata_parser = reqparse.RequestParser()
document_metadata_parser.add_argument(
    "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource):
    """Create metadata for a dataset, or list all of its metadata."""

    @service_api_ns.expect(metadata_create_parser)
    @service_api_ns.doc("create_dataset_metadata")
    @service_api_ns.doc(description="Create metadata for a dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            201: "Metadata created successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id):
        """Create a metadata entry and return it marshalled, with status 201.

        Raises:
            NotFound: the dataset does not exist.
        """
        # The request body is parsed and validated before the dataset lookup.
        payload = metadata_create_parser.parse_args()
        metadata_args = MetadataArgs.model_validate(payload)

        dataset_key = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_key)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        created = MetadataService.create_metadata(dataset_key, metadata_args)
        return marshal(created, dataset_metadata_fields), 201

    @service_api_ns.doc("get_dataset_metadata")
    @service_api_ns.doc(description="Get all metadata for a dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Metadata retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
        }
    )
    def get(self, tenant_id, dataset_id):
        """Return all metadata of the dataset.

        Raises:
            NotFound: the dataset does not exist.
        """
        # NOTE(review): unlike post(), this handler does not call
        # DatasetService.check_dataset_permission - confirm that is intended.
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")
        return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource):
    """Rename or delete one metadata entry of a dataset."""

    @service_api_ns.expect(metadata_update_parser)
    @service_api_ns.doc("update_dataset_metadata")
    @service_api_ns.doc(description="Update metadata name")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
    @service_api_ns.doc(
        responses={
            200: "Metadata updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset or metadata not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def patch(self, tenant_id, dataset_id, metadata_id):
        """Rename a metadata entry and return it marshalled.

        Raises:
            NotFound: the dataset does not exist.
        """
        payload = metadata_update_parser.parse_args()

        dataset_key = str(dataset_id)
        metadata_key = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_key)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        renamed = MetadataService.update_metadata_name(dataset_key, metadata_key, payload["name"])
        return marshal(renamed, dataset_metadata_fields), 200

    @service_api_ns.doc("delete_dataset_metadata")
    @service_api_ns.doc(description="Delete metadata")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
    @service_api_ns.doc(
        responses={
            204: "Metadata deleted successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset or metadata not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, tenant_id, dataset_id, metadata_id):
        """Delete a metadata entry.

        Raises:
            NotFound: the dataset does not exist.
        """
        dataset_key = str(dataset_id)
        metadata_key = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_key)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        MetadataService.delete_metadata(dataset_key, metadata_key)
        # NOTE(review): a bare int is returned as the response body; confirm
        # whether an empty body with status 204 was intended.
        return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
    """List the built-in metadata fields."""

    @service_api_ns.doc("get_built_in_fields")
    @service_api_ns.doc(description="Get all built-in metadata fields")
    @service_api_ns.doc(
        responses={
            200: "Built-in fields retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def get(self, tenant_id, dataset_id):
        """Return ``{"fields": [...]}``; ``dataset_id`` is not used by the lookup."""
        return {"fields": MetadataService.get_built_in_fields()}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
    """Enable or disable the built-in metadata field of a dataset."""

    @service_api_ns.doc("toggle_built_in_field")
    @service_api_ns.doc(description="Enable or disable built-in metadata field")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"})
    @service_api_ns.doc(
        responses={
            200: "Action completed successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
        """Enable or disable the built-in metadata field.

        Raises:
            NotFound: the dataset does not exist.
            InvalidActionError: ``action`` is neither "enable" nor "disable".
        """
        # Imported locally because this module does not import the error
        # module at file level.
        from controllers.service_api.dataset.error import InvalidActionError

        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        if action == "enable":
            MetadataService.enable_built_in_field(dataset)
        elif action == "disable":
            MetadataService.disable_built_in_field(dataset)
        else:
            # The route accepts any string; an unknown action used to be
            # reported as "success" without doing anything.
            raise InvalidActionError()
        return {"result": "success"}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource):
    """Bulk-update the metadata of documents in a dataset."""

    @service_api_ns.expect(document_metadata_parser)
    @service_api_ns.doc("update_documents_metadata")
    @service_api_ns.doc(description="Update metadata for multiple documents")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Documents metadata updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id, dataset_id):
        """Apply the posted ``operation_data`` to the dataset's documents.

        Raises:
            NotFound: the dataset does not exist.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        # The body is parsed only after the dataset and permission checks pass.
        payload = document_metadata_parser.parse_args()
        operation = MetadataOperationData.model_validate(payload)
        MetadataService.update_documents_metadata(dataset, operation)

        return {"result": "success"}, 200
|
||||
@@ -0,0 +1,246 @@
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.reqparse import ParseResult, RequestParser
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
# The route must be a plain string: as an f-string, "{uuid:dataset_id}" tried
# to format the ``uuid`` module and raised TypeError when the module loaded.
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
    """Resource for datasource plugins."""

    @service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
    @service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        params={
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)"
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Datasource plugins retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def get(self, tenant_id: str, dataset_id: str):
        """List the datasource plugins of the dataset's rag pipeline.

        The ``is_published`` query argument selects published (default) or
        draft plugins.

        Returns:
            The list of datasource plugins and status 200.
        """
        from flask_restx import inputs

        # ``type=bool`` would turn every non-empty string, including "false",
        # into True. ``inputs.boolean`` parses the text; an unparsable value
        # falls back to the default.
        is_published: bool = request.args.get("is_published", default=True, type=inputs.boolean)

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
            # The <uuid:...> converter yields a UUID; pass it on as text.
            tenant_id=tenant_id, dataset_id=str(dataset_id), is_published=is_published
        )
        return datasource_plugins, 200
|
||||
|
||||
|
||||
# The route must be a plain string: as an f-string, "{uuid:dataset_id}" and
# "{string:node_id}" tried to format the ``uuid`` and ``string`` modules and
# raised TypeError when the module loaded.
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
class DatasourceNodeRunApi(DatasetApiResource):
    """Resource for datasource node run."""

    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        body={
            "inputs": "User input variables",
            "datasource_type": "Datasource type, e.g. online_document",
            "credential_id": "Credential ID",
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Datasource node run successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def post(self, tenant_id: str, dataset_id: str, node_id: str):
        """Run one datasource node of the dataset's rag pipeline.

        Returns:
            The node run's events, converted to an event stream and wrapped by
            ``helper.compact_generate_response``.

        Raises:
            Forbidden: the current user is not an ``Account``.
        """
        parser: RequestParser = (
            reqparse.RequestParser()
            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
            .add_argument("datasource_type", type=str, required=True, location="json")
            .add_argument("credential_id", type=str, required=False, location="json")
            .add_argument("is_published", type=bool, required=True, location="json")
        )
        args: ParseResult = parser.parse_args()

        datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
        # An explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and must not guard request handling.
        if not isinstance(current_user, Account):
            raise Forbidden()

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        # The <uuid:...> converter yields a UUID; pass it on as text.
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=str(dataset_id))
        return helper.compact_generate_response(
            PipelineGenerator.convert_to_event_stream(
                rag_pipeline_service.run_datasource_workflow_node(
                    pipeline=pipeline,
                    node_id=node_id,
                    user_inputs=datasource_node_run_api_entity.inputs,
                    account=current_user,
                    datasource_type=datasource_node_run_api_entity.datasource_type,
                    is_published=datasource_node_run_api_entity.is_published,
                    credential_id=datasource_node_run_api_entity.credential_id,
                )
            )
        )
|
||||
|
||||
|
||||
# The route must be a plain string: as an f-string, "{uuid:dataset_id}" tried
# to format the ``uuid`` module and raised TypeError when the module loaded.
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
class PipelineRunApi(DatasetApiResource):
    """Resource for running a rag pipeline."""

    # NOTE(review): the shortcut and description below duplicate those of
    # DatasourceNodeRunApi, and the documented "streaming" field does not match
    # the parsed "response_mode" argument - confirm the intended API docs.
    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        body={
            "inputs": "User input variables",
            "datasource_type": "Datasource type, e.g. online_document",
            "datasource_info_list": "Datasource info list",
            "start_node_id": "Start node ID",
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)",
            "streaming": "Whether to stream the response(streaming or blocking), default: streaming",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Pipeline run successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def post(self, tenant_id: str, dataset_id: str):
        """Run the dataset's rag pipeline.

        Returns:
            The generator output wrapped by ``helper.compact_generate_response``.

        Raises:
            Forbidden: the current user is not an ``Account``.
            PipelineRunError: any exception raised while generating.
        """
        parser: RequestParser = (
            reqparse.RequestParser()
            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
            .add_argument("datasource_type", type=str, required=True, location="json")
            .add_argument("datasource_info_list", type=list, required=True, location="json")
            .add_argument("start_node_id", type=str, required=True, location="json")
            .add_argument("is_published", type=bool, required=True, default=True, location="json")
            .add_argument(
                "response_mode",
                type=str,
                required=True,
                choices=["streaming", "blocking"],
                default="blocking",
                location="json",
            )
        )
        args: ParseResult = parser.parse_args()

        if not isinstance(current_user, Account):
            raise Forbidden()

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        # The <uuid:...> converter yields a UUID; pass it on as text.
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=str(dataset_id))
        try:
            response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
                pipeline=pipeline,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
                streaming=args.get("response_mode") == "streaming",
            )

            return helper.compact_generate_response(response)
        except Exception as ex:
            # Chain the cause so the original traceback is kept for debugging.
            raise PipelineRunError(description=str(ex)) from ex
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/pipeline/file-upload")
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
    """Resource for uploading a file to a knowledgebase pipeline."""

    @service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
    @service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
    @service_api_ns.doc(
        responses={
            201: "File uploaded successfully",
            400: "Bad request - no file or invalid file",
            401: "Unauthorized - invalid API token",
            413: "File too large",
            415: "Unsupported file type",
        }
    )
    def post(self, tenant_id: str):
        """Store a single uploaded file and describe it.

        Expects exactly one multipart part named ``file``.

        Returns:
            A dict describing the stored file, with status 201.

        Raises:
            NoFileUploadedError: the request has no ``file`` part.
            TooManyFilesError: the request carries more than one file.
            UnsupportedFileTypeError: the part has no mimetype, or the file
                service rejects the type.
            FilenameNotExistsError: the part has no filename.
            ValueError: there is no current user.
            FileTooLargeError: the file service rejects the size.
        """
        uploads = request.files
        if "file" not in uploads:
            raise NoFileUploadedError()
        if len(uploads) > 1:
            raise TooManyFilesError()

        file = uploads["file"]
        if not file.mimetype:
            raise UnsupportedFileTypeError()
        if not file.filename:
            raise FilenameNotExistsError
        if not current_user:
            raise ValueError("Invalid user account")

        # NOTE(review): FileTooLargeError and UnsupportedFileTypeError are
        # imported from services.errors.file in this module, so the handlers
        # below re-raise the same service-level classes they catch. Confirm
        # whether HTTP-level error classes were meant to be raised instead.
        try:
            stored = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        file_info = {
            "id": stored.id,
            "name": stored.name,
            "size": stored.size,
            "extension": stored.extension,
            "mime_type": stored.mime_type,
            "created_by": stored.created_by,
            "created_at": stored.created_at,
        }
        return file_info, 201
|
||||
543
dify/api/controllers/service_api/dataset/segment.py
Normal file
543
dify/api/controllers/service_api/dataset/segment.py
Normal file
@@ -0,0 +1,543 @@
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_knowledge_limit_check,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
|
||||
# Request parsers shared by the segment and child-chunk resources below.

# POST .../segments: optional JSON list of segment payloads.
segment_create_parser = reqparse.RequestParser()
segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json")

# GET .../segments: repeatable "status" filter plus an optional keyword.
segment_list_parser = reqparse.RequestParser()
segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args")
segment_list_parser.add_argument("keyword", type=str, default=None, location="args")

# POST .../segments/<segment_id>: optional JSON object describing the update.
segment_update_parser = reqparse.RequestParser()
segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json")

# POST .../child_chunks: mandatory text content of the new child chunk.
child_chunk_create_parser = reqparse.RequestParser()
child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json")

# GET .../child_chunks: pagination and keyword filter.
child_chunk_list_parser = reqparse.RequestParser()
child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args")
child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args")
child_chunk_list_parser.add_argument("page", type=int, default=1, location="args")

# PATCH .../child_chunks/<child_chunk_id>: mandatory replacement content.
child_chunk_update_parser = reqparse.RequestParser()
child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class SegmentApi(DatasetApiResource):
    """Resource for segments."""

    @service_api_ns.expect(segment_create_parser)
    @service_api_ns.doc("create_segments")
    @service_api_ns.doc(description="Create segments in a document")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Segments created successfully",
            400: "Bad request - segments data is missing",
            401: "Unauthorized - invalid API token",
            404: "Dataset or document not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: str, document_id: str):
        """Create segments in a document."""
        # The docstring must be the first statement of the method; placed after
        # the assignment below it would be a discarded string expression.
        _, current_tenant_id = current_account_with_tenant()

        # check dataset (scoped to the tenant resolved from the API token)
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # check document: it must exist, be fully indexed and be enabled
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if document.indexing_status != "completed":
            raise NotFound("Document is not completed.")
        if not document.enabled:
            raise NotFound("Document is disabled.")

        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)

        # validate args
        args = segment_create_parser.parse_args()
        if args["segments"] is None:
            return {"error": "Segments is required"}, 400

        # A non-positive configured limit disables the per-request cap.
        segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
        if segments_limit > 0 and len(args["segments"]) > segments_limit:
            raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")

        # Validate every item before creating anything.
        for args_item in args["segments"]:
            SegmentService.segment_create_args_validate(args_item, document)

        segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
        return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200

    @service_api_ns.expect(segment_list_parser)
    @service_api_ns.doc("list_segments")
    @service_api_ns.doc(description="List segments in a document")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Segments retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset or document not found",
        }
    )
    def get(self, tenant_id: str, dataset_id: str, document_id: str):
        """Get segments."""
        _, current_tenant_id = current_account_with_tenant()

        # pagination comes straight from the query string
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)

        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # check document
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")

        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)

        args = segment_list_parser.parse_args()

        segments, total = SegmentService.get_segments(
            document_id=document_id,
            tenant_id=current_tenant_id,
            status_list=args["status"],
            keyword=args["keyword"],
            page=page,
            limit=limit,
        )

        response = {
            "data": marshal(segments, segment_fields),
            "doc_form": document.doc_form,
            "total": total,
            # NOTE(review): a full last page also reports has_more=True; confirm
            # whether clients rely on this before switching to a total-based check.
            "has_more": len(segments) == limit,
            "limit": limit,
            "page": page,
        }

        return response, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetSegmentApi(DatasetApiResource):
    """Resource for a single segment: delete, update and fetch."""

    @service_api_ns.doc("delete_segment")
    @service_api_ns.doc(description="Delete a specific segment")
    @service_api_ns.doc(
        params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"}
    )
    @service_api_ns.doc(
        responses={
            204: "Segment deleted successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, or segment not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
        """Delete a segment."""
        _, current_tenant_id = current_account_with_tenant()
        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")
        # validate segment belongs to the specified document; the lookup above is
        # only scoped by tenant, so a segment of another document would pass it
        if str(segment.document_id) != str(document_id):
            raise NotFound("Segment not found.")
        SegmentService.delete_segment(segment, document, dataset)
        # Return a (body, status) pair: a bare ``204`` would be serialised as the
        # response body and sent with status 200.
        return "", 204

    @service_api_ns.expect(segment_update_parser)
    @service_api_ns.doc("update_segment")
    @service_api_ns.doc(description="Update a specific segment")
    @service_api_ns.doc(
        params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"}
    )
    @service_api_ns.doc(
        responses={
            200: "Segment updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, or segment not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
        """Update a segment."""
        _, current_tenant_id = current_account_with_tenant()
        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        if dataset.indexing_technique == "high_quality":
            # check embedding model setting
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        # check segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")
        # validate segment belongs to the specified document
        if str(segment.document_id) != str(document_id):
            raise NotFound("Segment not found.")

        # validate args
        args = segment_update_parser.parse_args()

        updated_segment = SegmentService.update_segment(
            SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
        )
        return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

    @service_api_ns.doc("get_segment")
    @service_api_ns.doc(description="Get a specific segment by ID")
    @service_api_ns.doc(
        responses={
            200: "Segment retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, or segment not found",
        }
    )
    def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
        """Get a segment."""
        _, current_tenant_id = current_account_with_tenant()
        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")
        # validate segment belongs to the specified document
        if str(segment.document_id) != str(document_id):
            raise NotFound("Segment not found.")

        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks"
)
class ChildChunkApi(DatasetApiResource):
    """Resource for child chunks."""

    @service_api_ns.expect(child_chunk_create_parser)
    @service_api_ns.doc("create_child_chunk")
    @service_api_ns.doc(description="Create a new child chunk for a segment")
    @service_api_ns.doc(
        params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
    )
    @service_api_ns.doc(
        responses={
            200: "Child chunk created successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, or segment not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
        """Create child chunk."""
        # The docstring has to precede the first statement to be a docstring.
        _, current_tenant_id = current_account_with_tenant()

        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # check document
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")

        # check segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")

        # validate segment belongs to the specified document; the lookup above is
        # only scoped by tenant
        if str(segment.document_id) != str(document_id):
            raise NotFound("Segment not found.")

        # check embedding model setting
        if dataset.indexing_technique == "high_quality":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
                )
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)

        # validate args
        args = child_chunk_create_parser.parse_args()

        try:
            child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            # NOTE(review): both names are imported from services.errors.chunk in
            # this module, so this re-raises the same class that was caught.
            raise ChildChunkIndexingError(str(e)) from e

        return {"data": marshal(child_chunk, child_chunk_fields)}, 200

    @service_api_ns.expect(child_chunk_list_parser)
    @service_api_ns.doc("list_child_chunks")
    @service_api_ns.doc(description="List child chunks for a segment")
    @service_api_ns.doc(
        params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"}
    )
    @service_api_ns.doc(
        responses={
            200: "Child chunks retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, or segment not found",
        }
    )
    def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
        """Get child chunks."""
        _, current_tenant_id = current_account_with_tenant()

        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # check document
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")

        # check segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")

        # validate segment belongs to the specified document
        if str(segment.document_id) != str(document_id):
            raise NotFound("Segment not found.")

        args = child_chunk_list_parser.parse_args()

        page = args["page"]
        # cap the page size at 100 items
        limit = min(args["limit"], 100)
        keyword = args["keyword"]

        child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)

        return {
            "data": marshal(child_chunks.items, child_chunk_fields),
            "total": child_chunks.total,
            "total_pages": child_chunks.pages,
            "page": page,
            "limit": limit,
        }, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
)
class DatasetChildChunkApi(DatasetApiResource):
    """Resource for deleting and updating a single child chunk."""

    @service_api_ns.doc("delete_child_chunk")
    @service_api_ns.doc(description="Delete a specific child chunk")
    @service_api_ns.doc(
        params={
            "dataset_id": "Dataset ID",
            "document_id": "Document ID",
            "segment_id": "Parent segment ID",
            "child_chunk_id": "Child chunk ID to delete",
        }
    )
    @service_api_ns.doc(
        responses={
            204: "Child chunk deleted successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, segment, or child chunk not found",
        }
    )
    @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
        """Delete child chunk."""
        # The docstring has to precede the first statement to be a docstring.
        _, current_tenant_id = current_account_with_tenant()

        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # check document
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")

        # check segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")

        # validate segment belongs to the specified document
        if str(segment.document_id) != str(document_id):
            raise NotFound("Document not found.")

        # check child chunk
        child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
        if not child_chunk:
            raise NotFound("Child chunk not found.")

        # validate child chunk belongs to the specified segment
        if str(child_chunk.segment_id) != str(segment.id):
            raise NotFound("Child chunk not found.")

        try:
            SegmentService.delete_child_chunk(child_chunk, dataset)
        except ChildChunkDeleteIndexServiceError as e:
            raise ChildChunkDeleteIndexError(str(e)) from e

        # Return a (body, status) pair: a bare ``204`` would be serialised as the
        # response body and sent with status 200.
        return "", 204

    @service_api_ns.expect(child_chunk_update_parser)
    @service_api_ns.doc("update_child_chunk")
    @service_api_ns.doc(description="Update a specific child chunk")
    @service_api_ns.doc(
        params={
            "dataset_id": "Dataset ID",
            "document_id": "Document ID",
            "segment_id": "Parent segment ID",
            "child_chunk_id": "Child chunk ID to update",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Child chunk updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset, document, segment, or child chunk not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
        """Update child chunk."""
        _, current_tenant_id = current_account_with_tenant()

        # check dataset
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")

        # get document
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound("Document not found.")

        # get segment
        segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
        if not segment:
            raise NotFound("Segment not found.")

        # validate segment belongs to the specified document
        if str(segment.document_id) != str(document_id):
            raise NotFound("Segment not found.")

        # get child chunk
        child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
        if not child_chunk:
            raise NotFound("Child chunk not found.")

        # validate child chunk belongs to the specified segment
        if str(child_chunk.segment_id) != str(segment.id):
            raise NotFound("Child chunk not found.")

        # validate args
        args = child_chunk_update_parser.parse_args()

        try:
            child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e)) from e

        return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
14
dify/api/controllers/service_api/index.py
Normal file
14
dify/api/controllers/service_api/index.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import service_api_ns
|
||||
|
||||
|
||||
@service_api_ns.route("/")
class IndexApi(Resource):
    """Root endpoint of the service API."""

    def get(self):
        """Return the welcome banner with the API and server versions."""
        banner = dict(
            welcome="Dify OpenAPI",
            api_version="v1",
            server_version=dify_config.project.version,
        )
        return banner
|
||||
32
dify/api/controllers/service_api/workspace/models.py
Normal file
32
dify/api/controllers/service_api/workspace/models.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_dataset_token
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
@service_api_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource):
    @service_api_ns.doc("get_available_models")
    @service_api_ns.doc(description="Get available models by model type")
    @service_api_ns.doc(params={"model_type": "Type of model to retrieve"})
    @service_api_ns.doc(
        responses={
            200: "Models retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_dataset_token
    def get(self, _, model_type: str):
        """Get available models by model type.

        Returns a list of available models for the specified model type.
        """
        # The tenant is taken from the logged-in user, not from the ignored
        # positional argument supplied by the token decorator.
        workspace_id = current_user.current_tenant_id
        service = ModelProviderService()
        available = service.get_models_by_model_type(tenant_id=workspace_id, model_type=model_type)
        payload = {"data": available}
        return jsonable_encoder(payload)
|
||||
344
dify/api/controllers/service_api/wraps.py
Normal file
344
dify/api/controllers/service_api/wraps.py
Normal file
@@ -0,0 +1,344 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class WhereisUserArg(StrEnum):
|
||||
"""
|
||||
Enum for whereis_user_arg.
|
||||
"""
|
||||
|
||||
QUERY = auto()
|
||||
JSON = auto()
|
||||
FORM = auto()
|
||||
|
||||
|
||||
class FetchUserArg(BaseModel):
    """Options telling ``validate_app_token`` where to look for the end-user id."""

    # Request location (query string, JSON body or form data) holding "user".
    fetch_from: WhereisUserArg
    # When True, ``validate_app_token`` raises if no "user" value is supplied.
    required: bool = False
|
||||
|
||||
|
||||
def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token("app")
|
||||
|
||||
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
|
||||
if not app_model:
|
||||
raise Forbidden("The app no longer exists.")
|
||||
|
||||
if app_model.status != "normal":
|
||||
raise Forbidden("The app's status is abnormal.")
|
||||
|
||||
if not app_model.enable_api:
|
||||
raise Forbidden("The app's API service has been disabled.")
|
||||
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
|
||||
if tenant is None:
|
||||
raise ValueError("Tenant does not exist.")
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("The workspace's status is archived.")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
# If caller needs end-user context, attach EndUser to current_user
|
||||
if fetch_user_arg:
|
||||
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
|
||||
user_id = request.args.get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
|
||||
user_id = request.get_json().get("user")
|
||||
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
|
||||
user_id = request.form.get("user")
|
||||
else:
|
||||
user_id = None
|
||||
|
||||
if not user_id and fetch_user_arg.required:
|
||||
raise ValueError("Arg user must be provided.")
|
||||
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
|
||||
end_user = EndUserService.get_or_create_end_user(app_model, user_id)
|
||||
kwargs["end_user"] = end_user
|
||||
|
||||
# Set EndUser as current logged-in user for flask_login.current_user
|
||||
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
|
||||
else:
|
||||
# For service API without end-user context, ensure an Account is logged in
|
||||
# so services relying on current_account_with_tenant() work correctly.
|
||||
tenant_owner_info = (
|
||||
db.session.query(Tenant, Account)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.join(Account, TenantAccountJoin.account_id == Account.id)
|
||||
.where(
|
||||
Tenant.id == app_model.tenant_id,
|
||||
TenantAccountJoin.role == "owner",
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if tenant_owner_info:
|
||||
tenant_model, account = tenant_owner_info
|
||||
account.current_tenant = tenant_model
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account not found or tenant is not active.")
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
    """Build a decorator that blocks a view once a subscription quota is reached.

    Args:
        resource: Quota to check; "members", "apps", "vector_space" and
            "documents" are recognised, any other value performs no check.
        api_token_type: Token type passed to ``validate_and_get_api_token``.

    Raises (from the decorated view):
        Forbidden: billing is enabled and the selected quota has a positive
            limit that the current size has reached.
    """

    def interceptor(view: Callable[P, R]):
        # Preserve the wrapped view's name and docstring, as the sibling
        # billing decorators in this module already do.
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
            api_token = validate_and_get_api_token(api_token_type)
            features = FeatureService.get_features(api_token.tenant_id)

            if features.billing.enabled:
                members = features.members
                apps = features.apps
                vector_space = features.vector_space
                documents_upload_quota = features.documents_upload_quota

                # A limit of 0 (or less) means "unlimited", hence the `0 <` guard.
                if resource == "members" and 0 < members.limit <= members.size:
                    raise Forbidden("The number of members has reached the limit of your subscription.")
                elif resource == "apps" and 0 < apps.limit <= apps.size:
                    raise Forbidden("The number of apps has reached the limit of your subscription.")
                elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
                    raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
                elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
                    raise Forbidden("The number of documents has reached the limit of your subscription.")

            return view(*args, **kwargs)

        return decorated

    return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
    """Build a decorator that rejects "add_segment" calls on the sandbox plan.

    The check only applies when billing is enabled for the token's tenant;
    every other combination falls through to the wrapped view.
    """

    def interceptor(view: Callable[P, R]):
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
            token = validate_and_get_api_token(api_token_type)
            billing = FeatureService.get_features(token.tenant_id).billing

            # Short-circuit order matters: the plan is only read when billing is
            # enabled and the guarded resource is "add_segment".
            blocked = (
                billing.enabled and resource == "add_segment" and billing.subscription.plan == CloudPlan.SANDBOX
            )
            if blocked:
                raise Forbidden(
                    "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
                )

            return view(*args, **kwargs)

        return decorated

    return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
    """Build a decorator enforcing the knowledge-base request rate limit.

    Only ``resource == "knowledge"`` is rate limited.  Requests are recorded in
    a Redis sorted set per tenant, entries older than 60 seconds are dropped,
    and the call is refused with ``Forbidden`` (after writing a
    ``RateLimitLog`` row) when the remaining count exceeds the tenant's limit.
    """

    def interceptor(view: Callable[P, R]):
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
            token = validate_and_get_api_token(api_token_type)

            if resource != "knowledge":
                return view(*args, **kwargs)

            rate_limit = FeatureService.get_knowledge_rate_limit(token.tenant_id)
            if not rate_limit.enabled:
                return view(*args, **kwargs)

            # Millisecond timestamp doubles as both member and score.
            now_ms = int(time.time() * 1000)
            window_key = f"rate_limit_{token.tenant_id}"

            redis_client.zadd(window_key, {now_ms: now_ms})
            # Keep only the entries from the last 60 000 ms.
            redis_client.zremrangebyscore(window_key, 0, now_ms - 60000)

            if redis_client.zcard(window_key) > rate_limit.limit:
                # add ratelimit record
                db.session.add(
                    RateLimitLog(
                        tenant_id=token.tenant_id,
                        subscription_plan=rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                )
                db.session.commit()
                raise Forbidden("Sorry, you have reached the knowledge base request rate limit of your subscription.")

            return view(*args, **kwargs)

        return decorated

    return interceptor
|
||||
|
||||
|
||||
def _uuid_like_or_none(candidate: object) -> str | None:
    """Return ``str(candidate)`` when it has the textual shape of a UUID, else ``None``.

    The check is intentionally loose (36 characters containing exactly four
    hyphens); it only distinguishes a URL path id from other positional
    arguments and is not a full UUID validation.
    """
    try:
        text = str(candidate)
    except Exception:
        # Best effort: an object whose __str__ fails simply is not an id.
        # ``Exception`` (not a bare ``except``) so KeyboardInterrupt/SystemExit propagate.
        return None
    if len(text) == 36 and text.count("-") == 4:
        return text
    return None


def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
    """Decorator that authenticates a dataset API token and logs in the tenant owner.

    Usable both as ``@validate_dataset_token`` (``view`` is the function) and as
    ``@validate_dataset_token()`` / an entry in ``method_decorators``.

    For each request the wrapped view:

    1. Looks for a dataset id in ``kwargs["dataset_id"]`` or, failing that, in the
       first non-``self`` positional argument if it looks like a UUID.
    2. If an id was found, requires that the dataset exists and has API access enabled.
    3. Validates the ``"dataset"``-scoped bearer token.
    4. Loads the owner account of the token's tenant and installs it as the
       current user for the request.
    5. Calls ``view`` with the tenant id prepended to the positional arguments.

    Raises:
        NotFound: the dataset id does not match any dataset.
        Forbidden: the dataset has API access disabled.
        Unauthorized: the token is invalid, or the tenant / its owner account
            cannot be found.
    """

    def decorator(view: Callable[Concatenate[T, P], R]):
        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs):
            # Prefer an explicit keyword argument for the URL path dataset_id.
            dataset_id = kwargs.get("dataset_id")

            # Otherwise fall back to positional arguments.
            if not dataset_id and args:
                if len(args) > 1 and hasattr(args[0], "__dict__"):
                    # Likely a bound resource method: args[0] is ``self``,
                    # so the path parameter (if any) is args[1].
                    candidate = args[1]
                else:
                    # Plain function (or a single positional argument).
                    candidate = args[0]
                dataset_id = _uuid_like_or_none(candidate)

            # Validate dataset if dataset_id is provided
            # NOTE(review): this lookup runs before the token is validated and is
            # not restricted to the token's tenant — confirm that callers perform
            # a tenant-scoped lookup before acting on the dataset.
            if dataset_id:
                dataset_id = str(dataset_id)
                dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
                if not dataset:
                    raise NotFound("Dataset not found.")
                if not dataset.enable_api:
                    raise Forbidden("Dataset api access is not enabled.")

            api_token = validate_and_get_api_token("dataset")
            tenant_account_join = (
                db.session.query(Tenant, TenantAccountJoin)
                .where(Tenant.id == api_token.tenant_id)
                .where(TenantAccountJoin.tenant_id == Tenant.id)
                .where(TenantAccountJoin.role.in_(["owner"]))
                .where(Tenant.status == TenantStatus.NORMAL)
                .one_or_none()
            )  # TODO: only owner information is required, so only one is returned.
            if tenant_account_join:
                tenant, ta = tenant_account_join
                account = db.session.query(Account).where(Account.id == ta.account_id).first()
                # Login admin
                if account:
                    account.current_tenant = tenant
                    current_app.login_manager._update_request_context_with_user(account)  # type: ignore
                    user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
                else:
                    raise Unauthorized("Tenant owner account does not exist.")
            else:
                raise Unauthorized("Tenant does not exist.")

            # The tenant id is injected as the first positional argument of the view.
            return view(api_token.tenant_id, *args, **kwargs)

        return decorated

    if view:
        # Used without parentheses (or via method_decorators): wrap immediately.
        return decorator(view)

    # view is None: the decorator was called with parentheses and no arguments,
    # so return the real decorator to be applied to the view afterwards.
    return decorator
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
    """
    Validate the request's bearer token and return the matching ``ApiToken``.

    The ``Authorization`` header must have the form ``Bearer <token>`` (the scheme
    is compared case-insensitively). The token must exist with ``type == scope``.

    As a side effect, ``last_used_at`` is refreshed to the current time, but only
    when it is unset or older than one minute. This is done with a single
    ``UPDATE ... RETURNING`` so that frequent calls do not write on every request.
    When the update matches no row (token recently used, or token unknown), a
    plain ``SELECT`` decides between the two cases.

    :param scope: required ``ApiToken.type`` of the token.
    :return: the ``ApiToken`` row for the presented token.
    :raises Unauthorized: when the header is missing or malformed, the scheme is
        not ``Bearer``, or no token of the given scope matches.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header is None or " " not in auth_header:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")

    # A header such as "Bearer " contains a space but no token: split() then
    # yields a single element, which previously made the tuple unpacking raise
    # ValueError (HTTP 500). Treat it as a malformed header instead.
    header_parts = auth_header.split(None, 1)
    if len(header_parts) != 2:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")
    auth_scheme, auth_token = header_parts
    auth_scheme = auth_scheme.lower()

    if auth_scheme != "bearer":
        raise Unauthorized("Authorization scheme must be 'Bearer'")

    current_time = naive_utc_now()
    cutoff_time = current_time - timedelta(minutes=1)
    # expire_on_commit=False keeps the returned row's attributes loaded after commit.
    with Session(db.engine, expire_on_commit=False) as session:
        update_stmt = (
            update(ApiToken)
            .where(
                ApiToken.token == auth_token,
                (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
                ApiToken.type == scope,
            )
            .values(last_used_at=current_time)
            .returning(ApiToken)
        )
        result = session.execute(update_stmt)
        api_token = result.scalar_one_or_none()

        if not api_token:
            # Nothing updated: either the token was used within the last minute
            # or it does not exist. Look it up without the last_used_at filter.
            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
            api_token = session.scalar(stmt)
            if not api_token:
                raise Unauthorized("Access token is invalid")
        else:
            # Persist the refreshed last_used_at.
            session.commit()

    return api_token
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
    """Base resource for dataset endpoints; every method is wrapped by ``validate_dataset_token``."""

    method_decorators = [validate_dataset_token]

    def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
        """Return the dataset with ``dataset_id`` belonging to ``tenant_id``.

        Raises ``NotFound`` when no dataset matches both the id and the tenant.
        """
        tenant_scoped = db.session.query(Dataset).where(
            Dataset.id == dataset_id,
            Dataset.tenant_id == tenant_id,
        )
        found = tenant_scoped.first()
        if found:
            return found

        raise NotFound("Dataset not found.")
|
||||
Reference in New Issue
Block a user