dify
This commit is contained in:
0
dify/api/services/plugin/__init__.py
Normal file
0
dify/api/services/plugin/__init__.py
Normal file
212
dify/api/services/plugin/data_migration.py
Normal file
212
dify/api/services/plugin/data_migration.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import GenericProviderID, ModelProviderID, ToolProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls):
|
||||
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
|
||||
cls.migrate_datasets()
|
||||
cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls):
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
while True:
|
||||
sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
|
||||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
record_id = str(i.id)
|
||||
provider_name = str(i.provider_name)
|
||||
retrieval_model = i.retrieval_model
|
||||
logger.debug(
|
||||
"Processing dataset %s with retrieval model of type %s",
|
||||
record_id,
|
||||
type(retrieval_model),
|
||||
)
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
retrieval_model_changed = False
|
||||
if retrieval_model:
|
||||
if (
|
||||
"reranking_model" in retrieval_model
|
||||
and "reranking_provider_name" in retrieval_model["reranking_model"]
|
||||
and retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
):
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating {table_name} {record_id} "
|
||||
f"(reranking_provider_name: "
|
||||
f"{retrieval_model['reranking_model']['reranking_provider_name']})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
# update google to langgenius/gemini/google etc.
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
|
||||
retrieval_model["reranking_model"]["reranking_provider_name"]
|
||||
).to_string()
|
||||
retrieval_model_changed = True
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update provider name append with "langgenius/{provider_name}/{provider_name}"
|
||||
params = {"record_id": record_id}
|
||||
update_retrieval_model_sql = ""
|
||||
if retrieval_model and retrieval_model_changed:
|
||||
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
|
||||
params["retrieval_model"] = json.dumps(retrieval_model)
|
||||
|
||||
params["provider_name"] = ModelProviderID(provider_name).to_string()
|
||||
|
||||
sql = f"""update {table_name}
|
||||
set {provider_column_name} =
|
||||
:provider_name
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(sa.text(sql), params)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
|
||||
)
|
||||
continue
|
||||
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
failed_ids = []
|
||||
last_id = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
while True:
|
||||
sql = f"""
|
||||
SELECT id, {provider_column_name} AS provider_name
|
||||
FROM {table_name}
|
||||
WHERE {provider_column_name} NOT LIKE '%/%'
|
||||
AND {provider_column_name} IS NOT NULL
|
||||
AND {provider_column_name} != ''
|
||||
AND id > :last_id
|
||||
ORDER BY id ASC
|
||||
LIMIT 5000
|
||||
"""
|
||||
params = {"last_id": last_id or ""}
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql), params)
|
||||
|
||||
current_iter_count = 0
|
||||
batch_updates = []
|
||||
|
||||
for i in rs:
|
||||
current_iter_count += 1
|
||||
processed_count += 1
|
||||
record_id = str(i.id)
|
||||
last_id = record_id
|
||||
provider_name = str(i.provider_name)
|
||||
|
||||
if record_id in failed_ids:
|
||||
continue
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# update jina to langgenius/jina_tool/jina etc.
|
||||
updated_value = provider_cls(provider_name).to_string()
|
||||
batch_updates.append((updated_value, record_id))
|
||||
except Exception:
|
||||
failed_ids.append(record_id)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
|
||||
)
|
||||
continue
|
||||
|
||||
if batch_updates:
|
||||
update_sql = f"""
|
||||
UPDATE {table_name}
|
||||
SET {provider_column_name} = :updated_value
|
||||
WHERE id = :record_id
|
||||
"""
|
||||
conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
if not current_iter_count:
|
||||
break
|
||||
|
||||
click.echo(
|
||||
click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
|
||||
)
|
||||
137
dify/api/services/plugin/dependencies_analysis.py
Normal file
137
dify/api/services/plugin/dependencies_analysis.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import re
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import PluginDependency, PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from models.provider_ids import ModelProviderID, ToolProviderID
|
||||
|
||||
# Compile regex pattern for version extraction at module level for better performance
|
||||
_VERSION_REGEX = re.compile(r":(?P<version>[0-9]+(?:\.[0-9]+){2}(?:[+-][0-9A-Za-z.-]+)?)(?:@|$)")
|
||||
|
||||
|
||||
class DependenciesAnalysisService:
|
||||
@classmethod
|
||||
def analyze_tool_dependency(cls, tool_id: str) -> str:
|
||||
"""
|
||||
Analyze the dependency of a tool.
|
||||
|
||||
Convert the tool id to the plugin_id
|
||||
"""
|
||||
try:
|
||||
return ToolProviderID(tool_id).plugin_id
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def analyze_model_provider_dependency(cls, model_provider_id: str) -> str:
|
||||
"""
|
||||
Analyze the dependency of a model provider.
|
||||
|
||||
Convert the model provider id to the plugin_id
|
||||
"""
|
||||
try:
|
||||
return ModelProviderID(model_provider_id).plugin_id
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dependencies: list[PluginDependency]) -> list[PluginDependency]:
|
||||
"""
|
||||
Check dependencies, returns the leaked dependencies in current workspace
|
||||
"""
|
||||
required_plugin_unique_identifiers = []
|
||||
for dependency in dependencies:
|
||||
required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier)
|
||||
|
||||
manager = PluginInstaller()
|
||||
|
||||
# get leaked dependencies
|
||||
missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers)
|
||||
missing_plugin_unique_identifiers = {plugin.plugin_unique_identifier: plugin for plugin in missing_plugins}
|
||||
|
||||
leaked_dependencies = []
|
||||
for dependency in dependencies:
|
||||
unique_identifier = dependency.value.plugin_unique_identifier
|
||||
if unique_identifier in missing_plugin_unique_identifiers:
|
||||
# Extract version for Marketplace dependencies
|
||||
if dependency.type == PluginDependency.Type.Marketplace:
|
||||
version_match = _VERSION_REGEX.search(unique_identifier)
|
||||
if version_match:
|
||||
dependency.value.version = version_match.group("version")
|
||||
|
||||
# Create and append the dependency (same for all types)
|
||||
leaked_dependencies.append(
|
||||
PluginDependency(
|
||||
type=dependency.type,
|
||||
value=dependency.value,
|
||||
current_identifier=missing_plugin_unique_identifiers[unique_identifier].current_identifier,
|
||||
)
|
||||
)
|
||||
|
||||
return leaked_dependencies
|
||||
|
||||
@classmethod
|
||||
def generate_dependencies(cls, tenant_id: str, dependencies: list[str]) -> list[PluginDependency]:
|
||||
"""
|
||||
Generate dependencies through the list of plugin ids
|
||||
"""
|
||||
dependencies = list(set(dependencies))
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies)
|
||||
result = []
|
||||
for plugin in plugins:
|
||||
if plugin.source == PluginInstallationSource.Github:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Github,
|
||||
value=PluginDependency.Github(
|
||||
repo=plugin.meta["repo"],
|
||||
version=plugin.meta["version"],
|
||||
package=plugin.meta["package"],
|
||||
github_plugin_unique_identifier=plugin.plugin_unique_identifier,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Marketplace:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
value=PluginDependency.Marketplace(
|
||||
marketplace_plugin_unique_identifier=plugin.plugin_unique_identifier
|
||||
),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Package:
|
||||
result.append(
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Package,
|
||||
value=PluginDependency.Package(plugin_unique_identifier=plugin.plugin_unique_identifier),
|
||||
)
|
||||
)
|
||||
elif plugin.source == PluginInstallationSource.Remote:
|
||||
raise ValueError(
|
||||
f"You used a remote plugin: {plugin.plugin_unique_identifier} in the app, please remove it first"
|
||||
" if you want to export the DSL."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin source: {plugin.source}")
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def generate_latest_dependencies(cls, dependencies: list[str]) -> list[PluginDependency]:
|
||||
"""
|
||||
Generate the latest version of dependencies
|
||||
"""
|
||||
dependencies = list(set(dependencies))
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
return []
|
||||
deps = marketplace.batch_fetch_plugin_manifests(dependencies)
|
||||
return [
|
||||
PluginDependency(
|
||||
type=PluginDependency.Type.Marketplace,
|
||||
value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=dep.latest_package_identifier),
|
||||
)
|
||||
for dep in deps
|
||||
]
|
||||
66
dify/api/services/plugin/endpoint_service.py
Normal file
66
dify/api/services/plugin/endpoint_service.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from core.plugin.impl.endpoint import PluginEndpointClient
|
||||
|
||||
|
||||
class EndpointService:
|
||||
@classmethod
|
||||
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
|
||||
return PluginEndpointClient().create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int):
|
||||
return PluginEndpointClient().list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
|
||||
return PluginEndpointClient().list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
|
||||
return PluginEndpointClient().update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointClient().delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointClient().enable_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
return PluginEndpointClient().disable_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
65
dify/api/services/plugin/oauth_service.py
Normal file
65
dify/api/services/plugin/oauth_service.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class OAuthProxyService(BasePluginClient):
|
||||
# Default max age for proxy context parameter in seconds
|
||||
__MAX_AGE__ = 5 * 60 # 5 minutes
|
||||
__KEY_PREFIX__ = "oauth_proxy_context:"
|
||||
|
||||
@staticmethod
|
||||
def create_proxy_context(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
extra_data: dict = {},
|
||||
credential_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Create a proxy context for an OAuth 2.0 authorization request.
|
||||
|
||||
This parameter is a crucial security measure to prevent Cross-Site Request
|
||||
Forgery (CSRF) attacks. It works by generating a unique nonce and storing it
|
||||
in a distributed cache (Redis) along with the user's session context.
|
||||
|
||||
The returned nonce should be included as the 'proxy_context' parameter in the
|
||||
authorization URL. Upon callback, the `use_proxy_context` method
|
||||
is used to verify the state, ensuring the request's integrity and authenticity,
|
||||
and mitigating replay attacks.
|
||||
"""
|
||||
context_id = str(uuid.uuid4())
|
||||
data = {
|
||||
**extra_data,
|
||||
"user_id": user_id,
|
||||
"plugin_id": plugin_id,
|
||||
"tenant_id": tenant_id,
|
||||
"provider": provider,
|
||||
}
|
||||
if credential_id:
|
||||
data["credential_id"] = credential_id
|
||||
redis_client.setex(
|
||||
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
|
||||
OAuthProxyService.__MAX_AGE__,
|
||||
json.dumps(data),
|
||||
)
|
||||
return context_id
|
||||
|
||||
@staticmethod
|
||||
def use_proxy_context(context_id: str):
|
||||
"""
|
||||
Validate the proxy context parameter.
|
||||
This checks if the context_id is valid and not expired.
|
||||
"""
|
||||
if not context_id:
|
||||
raise ValueError("context_id is required")
|
||||
# get data from redis
|
||||
key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}"
|
||||
data = redis_client.get(key)
|
||||
if not data:
|
||||
raise ValueError("context_id is invalid")
|
||||
redis_client.delete(key)
|
||||
return json.loads(data)
|
||||
87
dify/api/services/plugin/plugin_auto_upgrade_service.py
Normal file
87
dify/api/services/plugin/plugin_auto_upgrade_service.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
|
||||
class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def change_strategy(
|
||||
tenant_id: str,
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
upgrade_time_of_day: int,
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
strategy = TenantPluginAutoUpgradeStrategy(
|
||||
tenant_id=tenant_id,
|
||||
strategy_setting=strategy_setting,
|
||||
upgrade_time_of_day=upgrade_time_of_day,
|
||||
upgrade_mode=upgrade_mode,
|
||||
exclude_plugins=exclude_plugins,
|
||||
include_plugins=include_plugins,
|
||||
)
|
||||
session.add(strategy)
|
||||
else:
|
||||
exist_strategy.strategy_setting = strategy_setting
|
||||
exist_strategy.upgrade_time_of_day = upgrade_time_of_day
|
||||
exist_strategy.upgrade_mode = upgrade_mode
|
||||
exist_strategy.exclude_plugins = exclude_plugins
|
||||
exist_strategy.include_plugins = include_plugins
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
# create for this tenant
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant_id,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
0,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
[plugin_id],
|
||||
[],
|
||||
)
|
||||
return True
|
||||
else:
|
||||
if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
|
||||
if plugin_id not in exist_strategy.exclude_plugins:
|
||||
new_exclude_plugins = exist_strategy.exclude_plugins.copy()
|
||||
new_exclude_plugins.append(plugin_id)
|
||||
exist_strategy.exclude_plugins = new_exclude_plugins
|
||||
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL:
|
||||
if plugin_id in exist_strategy.include_plugins:
|
||||
new_include_plugins = exist_strategy.include_plugins.copy()
|
||||
new_include_plugins.remove(plugin_id)
|
||||
exist_strategy.include_plugins = new_include_plugins
|
||||
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
|
||||
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exist_strategy.exclude_plugins = [plugin_id]
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
595
dify/api/services/plugin/plugin_migration.py
Normal file
595
dify/api/services/plugin/plugin_migration.py
Normal file
@@ -0,0 +1,595 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.helper import marketplace
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.provider_ids import ModelProviderID, ToolProviderID
|
||||
from models.tools import BuiltinToolProvider
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int):
|
||||
"""
|
||||
Migrate plugin.
|
||||
"""
|
||||
from threading import Lock
|
||||
|
||||
click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
|
||||
ended_at = datetime.datetime.now()
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
current_time = started_at
|
||||
|
||||
with Session(db.engine) as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
|
||||
handled_tenant_count = 0
|
||||
file_lock = Lock()
|
||||
counter_lock = Lock()
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
with flask_app.app_context():
|
||||
nonlocal handled_tenant_count
|
||||
try:
|
||||
plugins = cls.extract_installed_plugin_ids(tenant_id)
|
||||
# Use lock when writing to file
|
||||
with file_lock:
|
||||
with open(filepath, "a") as f:
|
||||
f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
|
||||
|
||||
# Use lock when updating counter
|
||||
with counter_lock:
|
||||
nonlocal handled_tenant_count
|
||||
handled_tenant_count += 1
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] "
|
||||
f"Processed {handled_tenant_count} tenants "
|
||||
f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
|
||||
f"{handled_tenant_count}/{total_tenant_count}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to process tenant %s", tenant_id)
|
||||
|
||||
futures = []
|
||||
|
||||
while current_time < ended_at:
|
||||
click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
|
||||
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
||||
interval = datetime.timedelta(days=1)
|
||||
# Process tenants in this batch
|
||||
with Session(db.engine) as session:
|
||||
# Calculate tenant count in next batch with current interval
|
||||
# Try different intervals until we find one with a reasonable tenant count
|
||||
test_intervals = [
|
||||
datetime.timedelta(days=1),
|
||||
datetime.timedelta(hours=12),
|
||||
datetime.timedelta(hours=6),
|
||||
datetime.timedelta(hours=3),
|
||||
datetime.timedelta(hours=1),
|
||||
]
|
||||
|
||||
tenant_count = 0
|
||||
for test_interval in test_intervals:
|
||||
tenant_count = (
|
||||
session.query(Tenant.id)
|
||||
.where(Tenant.created_at.between(current_time, current_time + test_interval))
|
||||
.count()
|
||||
)
|
||||
if tenant_count <= 100:
|
||||
interval = test_interval
|
||||
break
|
||||
else:
|
||||
# If all intervals have too many tenants, use minimum interval
|
||||
interval = datetime.timedelta(hours=1)
|
||||
|
||||
# Adjust interval to target ~100 tenants per batch
|
||||
if tenant_count > 0:
|
||||
# Scale interval based on ratio to target count
|
||||
interval = min(
|
||||
datetime.timedelta(days=1), # Max 1 day
|
||||
max(
|
||||
datetime.timedelta(hours=1), # Min 1 hour
|
||||
interval * (100 / tenant_count), # Scale to target 100
|
||||
),
|
||||
)
|
||||
|
||||
batch_end = min(current_time + interval, ended_at)
|
||||
|
||||
rs = (
|
||||
session.query(Tenant.id)
|
||||
.where(Tenant.created_at.between(current_time, batch_end))
|
||||
.order_by(Tenant.created_at)
|
||||
)
|
||||
|
||||
tenants = []
|
||||
for row in rs:
|
||||
tenant_id = str(row.id)
|
||||
try:
|
||||
tenants.append(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to process tenant %s", tenant_id)
|
||||
continue
|
||||
|
||||
futures.append(
|
||||
thread_pool.submit(
|
||||
process_tenant,
|
||||
current_app._get_current_object(), # type: ignore
|
||||
tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
current_time = batch_end
|
||||
|
||||
# wait for all threads to finish
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
@classmethod
|
||||
def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract installed plugin ids.
|
||||
"""
|
||||
tools = cls.extract_tool_tables(tenant_id)
|
||||
models = cls.extract_model_tables(tenant_id)
|
||||
workflows = cls.extract_workflow_tables(tenant_id)
|
||||
apps = cls.extract_app_tables(tenant_id)
|
||||
|
||||
return list({*tools, *models, *workflows, *apps})
|
||||
|
||||
@classmethod
|
||||
def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract model tables.
|
||||
|
||||
"""
|
||||
models: list[str] = []
|
||||
table_pairs = [
|
||||
("providers", "provider_name"),
|
||||
("provider_models", "provider_name"),
|
||||
("provider_orders", "provider_name"),
|
||||
("tenant_default_models", "provider_name"),
|
||||
("tenant_preferred_model_providers", "provider_name"),
|
||||
("provider_model_settings", "provider_name"),
|
||||
("load_balancing_model_configs", "provider_name"),
|
||||
]
|
||||
|
||||
for table, column in table_pairs:
|
||||
models.extend(cls.extract_model_table(tenant_id, table, column))
|
||||
|
||||
# duplicate models
|
||||
models = list(set(models))
|
||||
|
||||
return models
|
||||
|
||||
@classmethod
|
||||
def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract model table.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.execute(
|
||||
sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
)
|
||||
result = []
|
||||
for row in rs:
|
||||
provider_name = str(row[0])
|
||||
result.append(ModelProviderID(provider_name).plugin_id)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract tool tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
result.append(ToolProviderID(row.provider).plugin_id)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract workflow tables, only ToolNode is required.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
graph = row.graph_dict
|
||||
# get nodes
|
||||
nodes = graph.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
data = node.get("data", {})
|
||||
if data.get("type") == "tool":
|
||||
provider_name = data.get("provider_name")
|
||||
provider_type = data.get("provider_type")
|
||||
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN:
|
||||
result.append(ToolProviderID(provider_name).plugin_id)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
|
||||
"""
|
||||
Extract app tables.
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
apps = session.query(App).where(App.tenant_id == tenant_id).all()
|
||||
if not apps:
|
||||
return []
|
||||
|
||||
agent_app_model_config_ids = [
|
||||
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
|
||||
]
|
||||
|
||||
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
|
||||
result = []
|
||||
for row in rs:
|
||||
agent_config = row.agent_mode_dict
|
||||
if "tools" in agent_config and isinstance(agent_config["tools"], list):
|
||||
for tool in agent_config["tools"]:
|
||||
if isinstance(tool, dict):
|
||||
try:
|
||||
tool_entity = AgentToolEntity.model_validate(tool)
|
||||
if (
|
||||
tool_entity.provider_type == ToolProviderType.BUILT_IN
|
||||
and tool_entity.provider_id not in excluded_providers
|
||||
):
|
||||
result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to process tool %s", tool)
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
|
||||
"""
|
||||
Fetch plugin unique identifier using plugin id.
|
||||
"""
|
||||
plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
|
||||
if not plugin_manifest:
|
||||
return None
|
||||
|
||||
return plugin_manifest[0].latest_package_identifier
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
|
||||
plugins: dict[str, str] = {}
|
||||
plugin_ids = []
|
||||
plugin_not_exist = []
|
||||
logger.info("Extracting unique plugins from %s", extracted_plugins)
|
||||
with open(extracted_plugins) as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
new_plugin_ids = data.get("plugins", [])
|
||||
for plugin_id in new_plugin_ids:
|
||||
if plugin_id not in plugin_ids:
|
||||
plugin_ids.append(plugin_id)
|
||||
|
||||
def fetch_plugin(plugin_id):
|
||||
try:
|
||||
unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
|
||||
if unique_identifier:
|
||||
plugins[plugin_id] = unique_identifier
|
||||
else:
|
||||
plugin_not_exist.append(plugin_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch plugin unique identifier for %s", plugin_id)
|
||||
plugin_not_exist.append(plugin_id)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
|
||||
|
||||
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
|
||||
|
||||
@classmethod
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
|
||||
plugins = cls.extract_unique_plugins(extracted_plugins)
|
||||
not_installed = []
|
||||
plugin_install_failed = []
|
||||
|
||||
# use a fake tenant id to install all the plugins
|
||||
fake_tenant_id = uuid4().hex
|
||||
logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(tenant_id: str, plugin_ids: list[str]):
|
||||
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
|
||||
# fetch plugin already installed
|
||||
installed_plugins = manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
# at most 64 plugins one batch
|
||||
for i in range(0, len(plugin_ids), 64):
|
||||
batch_plugin_ids = plugin_ids[i : i + 64]
|
||||
batch_plugin_identifiers = [
|
||||
plugins["plugins"][plugin_id]
|
||||
for plugin_id in batch_plugin_ids
|
||||
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
|
||||
]
|
||||
manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
batch_plugin_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
|
||||
with open(extracted_plugins) as f:
|
||||
"""
|
||||
Read line by line, and install plugins for each tenant.
|
||||
"""
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
tenant_id = data.get("tenant_id")
|
||||
plugin_ids = data.get("plugins", [])
|
||||
current_not_installed = {
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_not_exist": [],
|
||||
}
|
||||
# get plugin unique identifier
|
||||
for plugin_id in plugin_ids:
|
||||
unique_identifier = plugins.get(plugin_id)
|
||||
if unique_identifier:
|
||||
current_not_installed["plugin_not_exist"].append(plugin_id)
|
||||
|
||||
if current_not_installed["plugin_not_exist"]:
|
||||
not_installed.append(current_not_installed)
|
||||
|
||||
thread_pool.submit(install, tenant_id, plugin_ids)
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
|
||||
logger.info("Uninstall plugins")
|
||||
|
||||
# get installation
|
||||
try:
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
while installation:
|
||||
for plugin in installation:
|
||||
manager.uninstall(fake_tenant_id, plugin.installation_id)
|
||||
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
|
||||
|
||||
Path(output_file).write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"not_installed": not_installed,
|
||||
"plugin_install_failed": plugin_install_failed,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def install_rag_pipeline_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
|
||||
"""
|
||||
Install rag pipeline plugins.
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
|
||||
plugins = cls.extract_unique_plugins(extracted_plugins)
|
||||
plugin_install_failed = []
|
||||
|
||||
# use a fake tenant id to install all the plugins
|
||||
fake_tenant_id = uuid4().hex
|
||||
logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(
|
||||
tenant_id: str, plugin_ids: dict[str, str], total_success_tenant: int, total_failed_tenant: int
|
||||
) -> None:
|
||||
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
|
||||
try:
|
||||
# fetch plugin already installed
|
||||
installed_plugins = manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
# at most 64 plugins one batch
|
||||
for i in range(0, len(plugin_ids), 64):
|
||||
batch_plugin_ids = list(plugin_ids.keys())[i : i + 64]
|
||||
batch_plugin_identifiers = [
|
||||
plugin_ids[plugin_id]
|
||||
for plugin_id in batch_plugin_ids
|
||||
if plugin_id not in installed_plugins_ids and plugin_id in plugin_ids
|
||||
]
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, batch_plugin_identifiers)
|
||||
|
||||
total_success_tenant += 1
|
||||
except Exception:
|
||||
logger.exception("Failed to install plugins for tenant %s", tenant_id)
|
||||
total_failed_tenant += 1
|
||||
|
||||
page = 1
|
||||
total_success_tenant = 0
|
||||
total_failed_tenant = 0
|
||||
while True:
|
||||
# paginate
|
||||
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
|
||||
if tenants.items is None or len(tenants.items) == 0:
|
||||
break
|
||||
|
||||
for tenant in tenants:
|
||||
tenant_id = tenant.id
|
||||
# get plugin unique identifier
|
||||
thread_pool.submit(
|
||||
install,
|
||||
tenant_id,
|
||||
plugins.get("plugins", {}),
|
||||
total_success_tenant,
|
||||
total_failed_tenant,
|
||||
)
|
||||
|
||||
page += 1
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
|
||||
# uninstall all the plugins for fake tenant
|
||||
try:
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
while installation:
|
||||
for plugin in installation:
|
||||
manager.uninstall(fake_tenant_id, plugin.installation_id)
|
||||
|
||||
installation = manager.list_plugins(fake_tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
|
||||
|
||||
Path(output_file).write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"total_success_tenant": total_success_tenant,
|
||||
"total_failed_tenant": total_failed_tenant,
|
||||
"plugin_install_failed": plugin_install_failed,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def handle_plugin_instance_install(
|
||||
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Install plugins for a tenant.
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
|
||||
# download all the plugins and upload
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
futures = []
|
||||
|
||||
for plugin_id, plugin_identifier in plugin_identifiers_map.items():
|
||||
|
||||
def download_and_upload(tenant_id, plugin_id, plugin_identifier):
|
||||
plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
|
||||
if not plugin_package:
|
||||
raise Exception(f"Failed to download plugin {plugin_identifier}")
|
||||
|
||||
# upload
|
||||
manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
|
||||
|
||||
futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
|
||||
|
||||
# Wait for all downloads to complete
|
||||
for future in futures:
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
thread_pool.shutdown(wait=True)
|
||||
success = []
|
||||
failed = []
|
||||
|
||||
reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
|
||||
|
||||
# at most 8 plugins one batch
|
||||
for i in range(0, len(plugin_identifiers_map), 8):
|
||||
batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
|
||||
batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
|
||||
|
||||
try:
|
||||
response = manager.install_from_identifiers(
|
||||
tenant_id=tenant_id,
|
||||
identifiers=batch_plugin_identifiers,
|
||||
source=PluginInstallationSource.Marketplace,
|
||||
metas=[
|
||||
{
|
||||
"plugin_unique_identifier": identifier,
|
||||
}
|
||||
for identifier in batch_plugin_identifiers
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
# add to failed
|
||||
failed.extend(batch_plugin_identifiers)
|
||||
continue
|
||||
|
||||
if response.all_installed:
|
||||
success.extend(batch_plugin_identifiers)
|
||||
continue
|
||||
|
||||
task_id = response.task_id
|
||||
done = False
|
||||
while not done:
|
||||
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
|
||||
for plugin in status.plugins:
|
||||
if plugin.status == PluginInstallTaskStatus.Success:
|
||||
success.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
else:
|
||||
failed.append(reverse_map[plugin.plugin_unique_identifier])
|
||||
logger.error(
|
||||
"Failed to install plugin %s, error: %s",
|
||||
plugin.plugin_unique_identifier,
|
||||
plugin.message,
|
||||
)
|
||||
|
||||
done = True
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
return {"success": success, "failed": failed}
|
||||
107
dify/api/services/plugin/plugin_parameter_service.py
Normal file
107
dify/api/services/plugin/plugin_parameter_service.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||
from core.trigger.entities.entities import SubscriptionBuilder
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
|
||||
|
||||
class PluginParameterService:
|
||||
@staticmethod
|
||||
def get_dynamic_select_options(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
parameter: str,
|
||||
credential_id: str | None,
|
||||
provider_type: Literal["tool", "trigger"],
|
||||
) -> Sequence[PluginParameterOption]:
|
||||
"""
|
||||
Get dynamic select options for a plugin parameter.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID.
|
||||
plugin_id: The plugin ID.
|
||||
provider: The provider name.
|
||||
action: The action name.
|
||||
parameter: The parameter name.
|
||||
"""
|
||||
credentials: Mapping[str, Any] = {}
|
||||
credential_type: str = CredentialType.UNAUTHORIZED.value
|
||||
match provider_type:
|
||||
case "tool":
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
# init tool configuration
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
# check if credentials are required
|
||||
if not provider_controller.need_credentials:
|
||||
credentials = {}
|
||||
else:
|
||||
# fetch credentials from db
|
||||
with Session(db.engine) as session:
|
||||
if credential_id:
|
||||
db_record = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
db_record = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_record is None:
|
||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||
|
||||
credentials = encrypter.decrypt(db_record.credentials)
|
||||
credential_type = db_record.credential_type
|
||||
case "trigger":
|
||||
subscription: TriggerProviderSubscriptionApiEntity | SubscriptionBuilder | None
|
||||
if credential_id:
|
||||
subscription = TriggerSubscriptionBuilderService.get_subscription_builder(credential_id)
|
||||
if not subscription:
|
||||
trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
|
||||
subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
|
||||
else:
|
||||
trigger_subscription = TriggerProviderService.get_subscription_by_id(tenant_id)
|
||||
subscription = trigger_subscription.to_api_entity() if trigger_subscription else None
|
||||
|
||||
if subscription is None:
|
||||
raise ValueError(f"Subscription {credential_id} not found")
|
||||
|
||||
credentials = subscription.credentials
|
||||
credential_type = subscription.credential_type or CredentialType.UNAUTHORIZED
|
||||
|
||||
return (
|
||||
DynamicSelectClient()
|
||||
.fetch_dynamic_select_options(
|
||||
tenant_id, user_id, plugin_id, provider, action, credentials, credential_type, parameter
|
||||
)
|
||||
.options
|
||||
)
|
||||
34
dify/api/services/plugin/plugin_permission_service.py
Normal file
34
dify/api/services/plugin/plugin_permission_service.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
|
||||
@staticmethod
|
||||
def change_permission(
|
||||
tenant_id: str,
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with Session(db.engine) as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
)
|
||||
if not permission:
|
||||
permission = TenantPluginPermission(
|
||||
tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission
|
||||
)
|
||||
|
||||
session.add(permission)
|
||||
else:
|
||||
permission.install_permission = install_permission
|
||||
permission.debug_permission = debug_permission
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
525
dify/api/services/plugin/plugin_service.py
Normal file
525
dify/api/services/plugin/plugin_service.py
Normal file
@@ -0,0 +1,525 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import marketplace
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.helper.marketplace import download_plugin_pkg
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginDecodeResponse,
|
||||
PluginInstallTask,
|
||||
PluginListResponse,
|
||||
PluginVerification,
|
||||
)
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import FeatureService, PluginInstallationScope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginService:
|
||||
class LatestPluginCache(BaseModel):
|
||||
plugin_id: str
|
||||
version: str
|
||||
unique_identifier: str
|
||||
status: str
|
||||
deprecated_reason: str
|
||||
alternative_plugin_id: str
|
||||
|
||||
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
|
||||
REDIS_TTL = 60 * 5 # 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
||||
"""
|
||||
Fetch the latest plugin version
|
||||
"""
|
||||
result: dict[str, PluginService.LatestPluginCache | None] = {}
|
||||
|
||||
try:
|
||||
cache_not_exists = []
|
||||
|
||||
# Try to get from Redis first
|
||||
for plugin_id in plugin_ids:
|
||||
cached_data = redis_client.get(f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}")
|
||||
if cached_data:
|
||||
result[plugin_id] = PluginService.LatestPluginCache.model_validate_json(cached_data)
|
||||
else:
|
||||
cache_not_exists.append(plugin_id)
|
||||
|
||||
if cache_not_exists:
|
||||
manifests = {
|
||||
manifest.plugin_id: manifest
|
||||
for manifest in marketplace.batch_fetch_plugin_manifests(cache_not_exists)
|
||||
}
|
||||
|
||||
for plugin_id, manifest in manifests.items():
|
||||
latest_plugin = PluginService.LatestPluginCache(
|
||||
plugin_id=plugin_id,
|
||||
version=manifest.latest_version,
|
||||
unique_identifier=manifest.latest_package_identifier,
|
||||
status=manifest.status,
|
||||
deprecated_reason=manifest.deprecated_reason,
|
||||
alternative_plugin_id=manifest.alternative_plugin_id,
|
||||
)
|
||||
|
||||
# Store in Redis
|
||||
redis_client.setex(
|
||||
f"{PluginService.REDIS_KEY_PREFIX}{plugin_id}",
|
||||
PluginService.REDIS_TTL,
|
||||
latest_plugin.model_dump_json(),
|
||||
)
|
||||
|
||||
result[plugin_id] = latest_plugin
|
||||
|
||||
# pop plugin_id from cache_not_exists
|
||||
cache_not_exists.remove(plugin_id)
|
||||
|
||||
for plugin_id in cache_not_exists:
|
||||
result[plugin_id] = None
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception("failed to fetch latest plugin version")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _check_marketplace_only_permission():
|
||||
"""
|
||||
Check if the marketplace only permission is enabled
|
||||
"""
|
||||
features = FeatureService.get_system_features()
|
||||
if features.plugin_installation_permission.restrict_to_marketplace_only:
|
||||
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
|
||||
|
||||
@staticmethod
|
||||
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
|
||||
"""
|
||||
Check the plugin installation scope
|
||||
"""
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
match features.plugin_installation_permission.plugin_installation_scope:
|
||||
case PluginInstallationScope.OFFICIAL_ONLY:
|
||||
if (
|
||||
plugin_verification is None
|
||||
or plugin_verification.authorized_category != PluginVerification.AuthorizedCategory.Langgenius
|
||||
):
|
||||
raise PluginInstallationForbiddenError("Plugin installation is restricted to official only")
|
||||
case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS:
|
||||
if plugin_verification is None or plugin_verification.authorized_category not in [
|
||||
PluginVerification.AuthorizedCategory.Langgenius,
|
||||
PluginVerification.AuthorizedCategory.Partner,
|
||||
]:
|
||||
raise PluginInstallationForbiddenError(
|
||||
"Plugin installation is restricted to official and specific partners"
|
||||
)
|
||||
case PluginInstallationScope.NONE:
|
||||
raise PluginInstallationForbiddenError("Installing plugins is not allowed")
|
||||
case PluginInstallationScope.ALL:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_debugging_key(tenant_id: str) -> str:
|
||||
"""
|
||||
get the debugging key of the tenant
|
||||
"""
|
||||
manager = PluginDebuggingClient()
|
||||
return manager.get_debugging_key(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
||||
"""
|
||||
List the latest versions of the plugins
|
||||
"""
|
||||
return PluginService.fetch_latest_plugin_version(plugin_ids)
|
||||
|
||||
@staticmethod
|
||||
def list(tenant_id: str) -> list[PluginEntity]:
|
||||
"""
|
||||
list all plugins of the tenant
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def list_with_total(tenant_id: str, page: int, page_size: int) -> PluginListResponse:
|
||||
"""
|
||||
list all plugins of the tenant
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins_with_total(tenant_id, page, page_size)
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]:
|
||||
"""
|
||||
List plugin installations from ids
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
|
||||
|
||||
@classmethod
|
||||
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
|
||||
url_prefix = (
|
||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
|
||||
)
|
||||
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
||||
|
||||
@staticmethod
|
||||
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
get the asset file of the plugin
|
||||
"""
|
||||
manager = PluginAssetManager()
|
||||
# guess mime type
|
||||
mime_type, _ = guess_type(asset_file)
|
||||
return manager.fetch_asset(tenant_id, asset_file), mime_type or "application/octet-stream"
|
||||
|
||||
@staticmethod
|
||||
def extract_asset(tenant_id: str, plugin_unique_identifier: str, file_name: str) -> bytes:
|
||||
manager = PluginAssetManager()
|
||||
return manager.extract_asset(tenant_id, plugin_unique_identifier, file_name)
|
||||
|
||||
@staticmethod
|
||||
def check_plugin_unique_identifier(tenant_id: str, plugin_unique_identifier: str) -> bool:
|
||||
"""
|
||||
check if the plugin unique identifier is already installed by other tenant
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier)
|
||||
|
||||
@staticmethod
|
||||
def fetch_plugin_manifest(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
"""
|
||||
Fetch plugin manifest
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
|
||||
"""
|
||||
Check if the plugin is verified
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
try:
|
||||
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
||||
"""
|
||||
Fetch plugin installation tasks
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
||||
|
||||
@staticmethod
|
||||
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task(tenant_id: str, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.delete_plugin_installation_task(tenant_id, task_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_all_install_task_items(
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete all plugin installation task items
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.delete_all_plugin_installation_task_items(tenant_id)
|
||||
|
||||
@staticmethod
|
||||
def delete_install_task_item(tenant_id: str, task_id: str, identifier: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task item
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier)
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_marketplace(
|
||||
tenant_id: str, original_plugin_unique_identifier: str, new_plugin_unique_identifier: str
|
||||
):
|
||||
"""
|
||||
Upgrade plugin with marketplace
|
||||
"""
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
raise ValueError("marketplace is not enabled")
|
||||
|
||||
if original_plugin_unique_identifier == new_plugin_unique_identifier:
|
||||
raise ValueError("you should not upgrade plugin with the same plugin")
|
||||
|
||||
# check if plugin pkg is already downloaded
|
||||
manager = PluginInstaller()
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
try:
|
||||
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
|
||||
# already downloaded, skip, and record install event
|
||||
marketplace.record_install_plugin_event(new_plugin_unique_identifier)
|
||||
except Exception:
|
||||
# plugin not installed, download and upload pkg
|
||||
pkg = download_plugin_pkg(new_plugin_unique_identifier)
|
||||
response = manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
|
||||
)
|
||||
|
||||
# check if the plugin is available to install
|
||||
PluginService._check_plugin_installation_scope(response.verification)
|
||||
|
||||
return manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_plugin_unique_identifier,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upgrade_plugin_with_github(
|
||||
tenant_id: str,
|
||||
original_plugin_unique_identifier: str,
|
||||
new_plugin_unique_identifier: str,
|
||||
repo: str,
|
||||
version: str,
|
||||
package: str,
|
||||
):
|
||||
"""
|
||||
Upgrade plugin with github
|
||||
"""
|
||||
PluginService._check_marketplace_only_permission()
|
||||
manager = PluginInstaller()
|
||||
return manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_plugin_unique_identifier,
|
||||
new_plugin_unique_identifier,
|
||||
PluginInstallationSource.Github,
|
||||
{
|
||||
"repo": repo,
|
||||
"version": version,
|
||||
"package": package,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
|
||||
"""
|
||||
Upload plugin package files
|
||||
|
||||
returns: plugin_unique_identifier
|
||||
"""
|
||||
PluginService._check_marketplace_only_permission()
|
||||
manager = PluginInstaller()
|
||||
features = FeatureService.get_system_features()
|
||||
response = manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
|
||||
)
|
||||
PluginService._check_plugin_installation_scope(response.verification)
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def upload_pkg_from_github(
|
||||
tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
|
||||
) -> PluginDecodeResponse:
|
||||
"""
|
||||
Install plugin from github release package files,
|
||||
returns plugin_unique_identifier
|
||||
"""
|
||||
PluginService._check_marketplace_only_permission()
|
||||
pkg = download_with_size_limit(
|
||||
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
|
||||
)
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
manager = PluginInstaller()
|
||||
response = manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
|
||||
)
|
||||
PluginService._check_plugin_installation_scope(response.verification)
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def upload_bundle(
|
||||
tenant_id: str, bundle: bytes, verify_signature: bool = False
|
||||
) -> Sequence[PluginBundleDependency]:
|
||||
"""
|
||||
Upload a plugin bundle and return the dependencies.
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
PluginService._check_marketplace_only_permission()
|
||||
return manager.upload_bundle(tenant_id, bundle, verify_signature)
|
||||
|
||||
@staticmethod
|
||||
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
|
||||
PluginService._check_marketplace_only_permission()
|
||||
|
||||
manager = PluginInstaller()
|
||||
|
||||
for plugin_unique_identifier in plugin_unique_identifiers:
|
||||
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||
PluginService._check_plugin_installation_scope(resp.verification)
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
plugin_unique_identifiers,
|
||||
PluginInstallationSource.Package,
|
||||
[{}],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
|
||||
"""
|
||||
Install plugin from github release package files,
|
||||
returns plugin_unique_identifier
|
||||
"""
|
||||
PluginService._check_marketplace_only_permission()
|
||||
|
||||
manager = PluginInstaller()
|
||||
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
[plugin_unique_identifier],
|
||||
PluginInstallationSource.Github,
|
||||
[
|
||||
{
|
||||
"repo": repo,
|
||||
"version": version,
|
||||
"package": package,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
"""
|
||||
Fetch marketplace package
|
||||
"""
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
raise ValueError("marketplace is not enabled")
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
manager = PluginInstaller()
|
||||
try:
|
||||
declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
except Exception:
|
||||
pkg = download_plugin_pkg(plugin_unique_identifier)
|
||||
response = manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
|
||||
)
|
||||
# check if the plugin is available to install
|
||||
PluginService._check_plugin_installation_scope(response.verification)
|
||||
declaration = response.manifest
|
||||
|
||||
return declaration
|
||||
|
||||
@staticmethod
|
||||
def install_from_marketplace_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
|
||||
"""
|
||||
Install plugin from marketplace package files,
|
||||
returns installation task id
|
||||
"""
|
||||
if not dify_config.MARKETPLACE_ENABLED:
|
||||
raise ValueError("marketplace is not enabled")
|
||||
|
||||
manager = PluginInstaller()
|
||||
|
||||
# collect actual plugin_unique_identifiers
|
||||
actual_plugin_unique_identifiers = []
|
||||
metas = []
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
# check if already downloaded
|
||||
for plugin_unique_identifier in plugin_unique_identifiers:
|
||||
try:
|
||||
manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
|
||||
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||
# check if the plugin is available to install
|
||||
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
|
||||
# already downloaded, skip
|
||||
actual_plugin_unique_identifiers.append(plugin_unique_identifier)
|
||||
metas.append({"plugin_unique_identifier": plugin_unique_identifier})
|
||||
except Exception:
|
||||
# plugin not installed, download and upload pkg
|
||||
pkg = download_plugin_pkg(plugin_unique_identifier)
|
||||
response = manager.upload_pkg(
|
||||
tenant_id,
|
||||
pkg,
|
||||
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
|
||||
)
|
||||
# check if the plugin is available to install
|
||||
PluginService._check_plugin_installation_scope(response.verification)
|
||||
# use response plugin_unique_identifier
|
||||
actual_plugin_unique_identifiers.append(response.unique_identifier)
|
||||
metas.append({"plugin_unique_identifier": response.unique_identifier})
|
||||
|
||||
return manager.install_from_identifiers(
|
||||
tenant_id,
|
||||
actual_plugin_unique_identifiers,
|
||||
PluginInstallationSource.Marketplace,
|
||||
metas,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
manager = PluginInstaller()
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
@staticmethod
|
||||
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.check_tools_existence(tenant_id, provider_ids)
|
||||
|
||||
@staticmethod
|
||||
def fetch_plugin_readme(tenant_id: str, plugin_unique_identifier: str, language: str) -> str:
|
||||
"""
|
||||
Fetch plugin readme
|
||||
"""
|
||||
manager = PluginInstaller()
|
||||
return manager.fetch_plugin_readme(tenant_id, plugin_unique_identifier, language)
|
||||
Reference in New Issue
Block a user