dify

dify/api/core/plugin/utils/chunk_merger.py  (new file, 95 lines)
@@ -0,0 +1,95 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TypeVar, Union

from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage

MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])


@dataclass
class FileChunk:
    """
    Buffer for accumulating file chunks during streaming.
    """

    total_length: int
    bytes_written: int = field(default=0, init=False)
    data: bytearray = field(init=False)

    def __post_init__(self):
        self.data = bytearray(self.total_length)


def merge_blob_chunks(
    response: Generator[MessageType, None, None],
    max_file_size: int = 30 * 1024 * 1024,
    max_chunk_size: int = 8192,
) -> Generator[MessageType, None, None]:
    """
    Merge streaming blob chunks into complete blob messages.

    This function processes a stream of plugin invoke messages, accumulating
    BLOB_CHUNK messages by their ID until the final chunk is received,
    then yielding a single complete BLOB message.

    Args:
        response: Generator yielding messages that may include blob chunks
        max_file_size: Maximum allowed file size in bytes (default: 30MB)
        max_chunk_size: Maximum allowed chunk size in bytes (default: 8KB)

    Yields:
        Messages from the response stream, with blob chunks merged into complete blobs

    Raises:
        ValueError: If file size exceeds max_file_size or chunk size exceeds max_chunk_size
    """
    files: dict[str, FileChunk] = {}

    for resp in response:
        if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
            assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
            # Get blob chunk information
            chunk_id = resp.message.id
            total_length = resp.message.total_length
            blob_data = resp.message.blob
            is_end = resp.message.end

            # Initialize buffer for this file if it doesn't exist
            if chunk_id not in files:
                files[chunk_id] = FileChunk(total_length)

            # Check if the file is too large (before appending)
            if files[chunk_id].bytes_written + len(blob_data) > max_file_size:
                # Drop the buffer if the file is too large
                del files[chunk_id]
                raise ValueError(f"File is too large; it exceeds the limit of {max_file_size / 1024 / 1024}MB")

            # Check if a single chunk is too large
            if len(blob_data) > max_chunk_size:
                raise ValueError(f"File chunk is too large; it exceeds the limit of {max_chunk_size / 1024}KB")

            # Append the blob data to the buffer
            files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = (
                blob_data
            )
            files[chunk_id].bytes_written += len(blob_data)

            # If this is the final chunk, yield a complete blob message
            if is_end:
                # Create the appropriate message type based on the response type
                message_class = type(resp)
                merged_message = message_class(
                    type=ToolInvokeMessage.MessageType.BLOB,
                    message=ToolInvokeMessage.BlobMessage(
                        blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
                    ),
                    meta=resp.meta,
                )
                assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
                yield merged_message  # type: ignore
                # Clean up the buffer
                del files[chunk_id]
        else:
            yield resp
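
Usage sketch (not part of the diff): how a chunked stream might be merged. The ToolInvokeMessage constructor and BlobChunkMessage fields are assumed from what the code above reads (type, message, meta, id, total_length, blob, end); the exact signatures and the import path may differ.

from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.tools.entities.tool_entities import ToolInvokeMessage


def fake_stream():
    # Two chunks of the same file, correlated by id; the second carries end=True.
    payload = b"hello world"
    half = len(payload) // 2
    for blob, end in ((payload[:half], False), (payload[half:], True)):
        yield ToolInvokeMessage(
            type=ToolInvokeMessage.MessageType.BLOB_CHUNK,
            message=ToolInvokeMessage.BlobChunkMessage(
                id="file-1", total_length=len(payload), blob=blob, end=end
            ),
        )


merged = list(merge_blob_chunks(fake_stream()))
# A single BLOB message comes out, carrying the reassembled payload.
assert merged[0].type == ToolInvokeMessage.MessageType.BLOB
assert merged[0].message.blob == b"hello world"
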
dify/api/core/plugin/utils/converter.py  (new file, 21 lines)
@@ -0,0 +1,21 @@
from typing import Any

from core.file.models import File
from core.tools.entities.tool_entities import ToolSelector


def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:
    """
    Convert File and ToolSelector parameter values, and lists of them,
    into their plugin-serializable form, mutating the dict in place.
    """
    for parameter_name, parameter in parameters.items():
        if isinstance(parameter, File):
            parameters[parameter_name] = parameter.to_plugin_parameter()
        elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter):
            parameters[parameter_name] = [p.to_plugin_parameter() for p in parameter]
        elif isinstance(parameter, ToolSelector):
            parameters[parameter_name] = parameter.to_plugin_parameter()
        elif isinstance(parameter, list) and all(isinstance(p, ToolSelector) for p in parameter):
            parameters[parameter_name] = [p.to_plugin_parameter() for p in parameter]
    return parameters
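
Usage sketch (not part of the diff): file_a and file_b stand in for hypothetical File objects loaded elsewhere; to_plugin_parameter() is taken from the code above, and the import path is assumed from the file location.

from core.plugin.utils.converter import convert_parameters_to_plugin_format

params = {"query": "hello", "attachments": [file_a, file_b]}
params = convert_parameters_to_plugin_format(params)
# "query" is returned untouched; each File in "attachments" has been
# replaced by the value of its to_plugin_parameter() call.
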
dify/api/core/plugin/utils/http_parser.py  (new file, 163 lines)
@@ -0,0 +1,163 @@
from io import BytesIO

from flask import Request, Response
from werkzeug.datastructures import Headers


def serialize_request(request: Request) -> bytes:
    """Serialize a Flask request into raw HTTP/1.1 wire format."""
    method = request.method
    path = request.full_path.rstrip("?")
    raw = f"{method} {path} HTTP/1.1\r\n".encode()

    for name, value in request.headers.items():
        raw += f"{name}: {value}\r\n".encode()

    raw += b"\r\n"

    body = request.get_data(as_text=False)
    if body:
        raw += body

    return raw

def deserialize_request(raw_data: bytes) -> Request:
    """Parse raw HTTP/1.1 wire bytes back into a Flask Request."""
    # Split headers from body, tolerating bare-LF as well as CRLF separators.
    header_end = raw_data.find(b"\r\n\r\n")
    if header_end == -1:
        header_end = raw_data.find(b"\n\n")
        if header_end == -1:
            header_data = raw_data
            body = b""
        else:
            header_data = raw_data[:header_end]
            body = raw_data[header_end + 2 :]
    else:
        header_data = raw_data[:header_end]
        body = raw_data[header_end + 4 :]

    lines = header_data.split(b"\r\n")
    if len(lines) == 1 and b"\n" in lines[0]:
        lines = header_data.split(b"\n")

    if not lines or not lines[0]:
        raise ValueError("Empty HTTP request")

    # Request line: "<METHOD> <PATH> <PROTOCOL>".
    request_line = lines[0].decode("utf-8", errors="ignore")
    parts = request_line.split(" ", 2)
    if len(parts) < 2:
        raise ValueError(f"Invalid request line: {request_line}")

    method = parts[0]
    full_path = parts[1]
    protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"

    if "?" in full_path:
        path, query_string = full_path.split("?", 1)
    else:
        path = full_path
        query_string = ""

    headers = Headers()
    for line in lines[1:]:
        if not line:
            continue
        line_str = line.decode("utf-8", errors="ignore")
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        headers.add(name, value.strip())

    host = headers.get("Host", "localhost")
    if ":" in host:
        server_name, server_port = host.rsplit(":", 1)
    else:
        server_name = host
        server_port = "80"

    # Build a minimal WSGI environ and hand it to Werkzeug.
    environ = {
        "REQUEST_METHOD": method,
        "PATH_INFO": path,
        "QUERY_STRING": query_string,
        "SERVER_NAME": server_name,
        "SERVER_PORT": server_port,
        "SERVER_PROTOCOL": protocol,
        "wsgi.input": BytesIO(body),
        "wsgi.url_scheme": "http",
    }

    if "Content-Type" in headers:
        content_type = headers.get("Content-Type")
        if content_type is not None:
            environ["CONTENT_TYPE"] = content_type

    if "Content-Length" in headers:
        content_length = headers.get("Content-Length")
        if content_length is not None:
            environ["CONTENT_LENGTH"] = content_length
    elif body:
        environ["CONTENT_LENGTH"] = str(len(body))

    # All remaining headers become HTTP_* environ keys.
    for name, value in headers.items():
        if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
            continue
        env_name = f"HTTP_{name.upper().replace('-', '_')}"
        environ[env_name] = value

    return Request(environ)

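Usage sketch (not part of the diff): a request round-trip through the two helpers above, built with Flask's standard test_request_context; the import path is assumed from the file location.

from flask import Flask, request

from core.plugin.utils.http_parser import deserialize_request, serialize_request

app = Flask(__name__)

with app.test_request_context(
    "/echo?x=1",
    method="POST",
    data=b'{"a": 1}',
    headers={"Content-Type": "application/json"},
):
    raw = serialize_request(request)

restored = deserialize_request(raw)
assert restored.method == "POST"
assert restored.path == "/echo"
assert restored.args["x"] == "1"
assert restored.get_data() == b'{"a": 1}'
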
def serialize_response(response: Response) -> bytes:
    """Serialize a Flask response into raw HTTP/1.1 wire format."""
    raw = f"HTTP/1.1 {response.status}\r\n".encode()

    for name, value in response.headers.items():
        raw += f"{name}: {value}\r\n".encode()

    raw += b"\r\n"

    body = response.get_data(as_text=False)
    if body:
        raw += body

    return raw

def deserialize_response(raw_data: bytes) -> Response:
    """Parse raw HTTP/1.1 wire bytes back into a Flask Response."""
    # Split headers from body, tolerating bare-LF as well as CRLF separators.
    header_end = raw_data.find(b"\r\n\r\n")
    if header_end == -1:
        header_end = raw_data.find(b"\n\n")
        if header_end == -1:
            header_data = raw_data
            body = b""
        else:
            header_data = raw_data[:header_end]
            body = raw_data[header_end + 2 :]
    else:
        header_data = raw_data[:header_end]
        body = raw_data[header_end + 4 :]

    lines = header_data.split(b"\r\n")
    if len(lines) == 1 and b"\n" in lines[0]:
        lines = header_data.split(b"\n")

    if not lines or not lines[0]:
        raise ValueError("Empty HTTP response")

    # Status line: "HTTP/1.1 <CODE> <REASON>".
    status_line = lines[0].decode("utf-8", errors="ignore")
    parts = status_line.split(" ", 2)
    if len(parts) < 2:
        raise ValueError(f"Invalid status line: {status_line}")

    status_code = int(parts[1])

    response = Response(response=body, status=status_code)

    for line in lines[1:]:
        if not line:
            continue
        line_str = line.decode("utf-8", errors="ignore")
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        response.headers[name] = value.strip()

    return response
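
Usage sketch (not part of the diff): the response-side equivalent, again assuming the module path.

from flask import Response

from core.plugin.utils.http_parser import deserialize_response, serialize_response

resp = Response(response=b"ok", status=200, headers={"X-Trace": "abc"})
raw = serialize_response(resp)

restored = deserialize_response(raw)
assert restored.status_code == 200
assert restored.headers["X-Trace"] == "abc"
assert restored.get_data() == b"ok"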