import json import logging import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError from pydantic import BaseModel, ConfigDict, Field, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities import ( CustomConfiguration, ModelSettings, SystemConfiguration, SystemConfigurationStatus, ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.provider import ( LoadBalancingModelConfig, Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderModelSetting, ProviderType, TenantPreferredModelProvider, ) from models.provider_ids import ModelProviderID from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): """ Provider configuration entity for managing model provider settings. This class handles: - Provider credentials CRUD and switch - Custom Model credentials CRUD and switch - System vs custom provider switching - Load balancing configurations - Model enablement/disablement TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ tenant_id: str provider: ProviderEntity preferred_provider_type: ProviderType using_provider_type: ProviderType system_configuration: SystemConfiguration custom_configuration: CustomConfiguration model_settings: list[ModelSettings] # pydantic configs model_config = ConfigDict(protected_namespaces=()) @model_validator(mode="after") def _(self): if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in self.provider.configurate_methods: original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: if ( any( len(quota_configuration.restrict_models) > 0 for quota_configuration in self.system_configuration.quota_configurations ) and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. :param model_type: model type :param model: model name :return: """ if self.model_settings: # check if model is disabled by admin for model_setting in self.model_settings: if model_setting.model_type == model_type and model_setting.model == model: if not model_setting.enabled: raise ValueError(f"Model {model} is disabled.") if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] for quota_configuration in self.system_configuration.quota_configurations: if self.system_configuration.current_quota_type != quota_configuration.quota_type: continue restrict_models = quota_configuration.restrict_models copy_credentials = ( self.system_configuration.credentials.copy() if self.system_configuration.credentials else {} ) if restrict_models: for restrict_model in restrict_models: if ( restrict_model.model_type == model_type and restrict_model.model == model and restrict_model.base_model_name ): copy_credentials["base_model_name"] = restrict_model.base_model_name return copy_credentials else: credentials = None current_credential_id = None if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: if model_configuration.model_type == model_type and model_configuration.model == model: credentials = model_configuration.credentials current_credential_id = model_configuration.current_credential_id break if not credentials and self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials current_credential_id = self.custom_configuration.provider.current_credential_id if current_credential_id: from core.helper.credential_utils import check_credential_policy_compliance check_credential_policy_compliance( credential_id=current_credential_id, provider=self.provider.provider, credential_type=PluginCredentialType.MODEL, ) else: # no current credential id, check all available credentials if self.custom_configuration.provider: for credential_configuration in self.custom_configuration.provider.available_credentials: from core.helper.credential_utils import check_credential_policy_compliance check_credential_policy_compliance( credential_id=credential_configuration.credential_id, provider=self.provider.provider, credential_type=PluginCredentialType.MODEL, ) return credentials def get_system_configuration_status(self) -> SystemConfigurationStatus | None: """ Get system configuration status. :return: """ if self.system_configuration.enabled is False: return SystemConfigurationStatus.UNSUPPORTED current_quota_type = self.system_configuration.current_quota_type current_quota_configuration = next( (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) if current_quota_configuration is None: return None if not current_quota_configuration: return SystemConfigurationStatus.UNSUPPORTED return ( SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else SystemConfigurationStatus.QUOTA_EXCEEDED ) def is_custom_configuration_available(self) -> bool: """ Check custom configuration available. :return: """ has_provider_credentials = ( self.custom_configuration.provider is not None and len(self.custom_configuration.provider.available_credentials) > 0 ) has_model_configurations = len(self.custom_configuration.models) > 0 return has_provider_credentials or has_model_configurations def _get_provider_record(self, session: Session) -> Provider | None: """ Get custom provider record. """ stmt = select(Provider).where( Provider.tenant_id == self.tenant_id, Provider.provider_type == ProviderType.CUSTOM, Provider.provider_name.in_(self._get_provider_names()), ) return session.execute(stmt).scalar_one_or_none() def _get_specific_provider_credential(self, credential_id: str) -> dict | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID :return: """ # Extract secret variables from provider credential schema credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) with Session(db.engine) as session: # Prefer the actual provider record name if exists (to handle aliased provider names) provider_record = self._get_provider_record(session) provider_name = provider_record.provider_name if provider_record else self.provider.provider stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name == provider_name, ) credential = session.execute(stmt).scalar_one_or_none() if not credential or not credential.encrypted_config: raise ValueError(f"Credential with id {credential_id} not found.") try: credentials = json.loads(credential.encrypted_config) except JSONDecodeError: credentials = {} # Decrypt secret variables for key in credential_secret_variables: if key in credentials and credentials[key] is not None: try: credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) except Exception: pass return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [], ) def _check_provider_credential_name_exists( self, credential_name: str, session: Session, exclude_id: str | None = None ) -> bool: """ not allowed same name when create or update a credential """ stmt = select(ProviderCredential.id).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.credential_name == credential_name, ) if exclude_id: stmt = stmt.where(ProviderCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None def get_provider_credential(self, credential_id: str | None = None) -> dict | None: """ Get provider credentials. :param credential_id: if provided, return the specified credential :return: """ if credential_id: return self._get_specific_provider_credential(credential_id) # Default behavior: return current active provider credentials credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {} return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [], ) def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): """ Validate custom credentials. :param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :param session: optional database session :return: """ def _validate(s: Session): # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas if self.provider.provider_credential_schema else [] ) if credential_id: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) credential_record = s.execute(stmt).scalar_one_or_none() # fix origin data if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} else: original_credentials = json.loads(credential_record.encrypted_config) else: original_credentials = {} except JSONDecodeError: original_credentials = {} # encrypt credentials for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value if value == HIDDEN_VALUE and key in original_credentials: credentials[key] = encrypter.decrypt_token( tenant_id=self.tenant_id, token=original_credentials[key] ) model_provider_factory = ModelProviderFactory(self.tenant_id) validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) for key, value in validated_credentials.items(): if key in provider_credential_secret_variables: validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) return validated_credentials if session: return _validate(session) else: with Session(db.engine) as new_session: return _validate(new_session) def _generate_provider_credential_name(self, session) -> str: """ Generate a unique credential name for provider. :return: credential name """ return self._generate_next_api_key_name( session=session, query_factory=lambda: select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ), ) def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str: """ Generate a unique credential name for custom model. :return: credential name """ return self._generate_next_api_key_name( session=session, query_factory=lambda: select(ProviderModelCredential).where( ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ), ) def _generate_next_api_key_name(self, session, query_factory) -> str: """ Generate next available API KEY name by finding the highest numbered suffix. :param session: database session :param query_factory: function that returns the SQLAlchemy query :return: next available API KEY name """ try: stmt = query_factory() credential_records = session.execute(stmt).scalars().all() if not credential_records: return "API KEY 1" # Extract numbers from API KEY pattern using list comprehension pattern = re.compile(r"^API KEY\s+(\d+)$") numbers = [ int(match.group(1)) for cr in credential_records if cr.credential_name and (match := pattern.match(cr.credential_name.strip())) ] # Return next sequential number next_number = max(numbers, default=0) + 1 return f"API KEY {next_number}" except Exception as e: logger.warning("Error generating next credential name: %s", str(e)) return "API KEY 1" def _get_provider_names(self): """ The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`. """ model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) return provider_names def create_provider_credential(self, credentials: dict, credential_name: str | None): """ Add custom provider credentials. :param credentials: provider credentials :param credential_name: credential name :return: """ with Session(db.engine) as session: if credential_name: if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: credential_name = self._generate_provider_credential_name(session) credentials = self.validate_provider_credentials(credentials=credentials, session=session) provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( tenant_id=self.tenant_id, provider_name=self.provider.provider, encrypted_config=json.dumps(credentials), credential_name=credential_name, ) session.add(new_record) session.flush() if not provider_record: # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, provider_type=ProviderType.CUSTOM, is_valid=True, credential_id=new_record.id, ) session.add(provider_record) provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) else: # some historical data may have a provider record but not be set as valid provider_record.is_valid = True session.commit() except Exception: session.rollback() raise def update_provider_credential( self, credentials: dict, credential_id: str, credential_name: str | None, ): """ update a saved provider credential (by credential_id). :param credentials: provider credentials :param credential_id: credential id :param credential_name: credential name :return: """ with Session(db.engine) as session: if credential_name and self._check_provider_credential_name_exists( credential_name=credential_name, session=session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") credentials = self.validate_provider_credentials( credentials=credentials, credential_id=credential_id, session=session ) provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ) # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: credential_record.credential_name = credential_name session.commit() if provider_record and provider_record.credential_id == credential_id: provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, credential_source="provider", session=session, ) except Exception: session.rollback() raise def _update_load_balancing_configs_with_credential( self, credential_id: str, credential_record: ProviderCredential | ProviderModelCredential, credential_source: str, session: Session, ): """ Update load balancing configurations that reference the given credential_id. :param credential_id: credential id :param credential_record: the encrypted_config and credential_name :param credential_source: the credential comes from the provider_credential(`provider`) or the provider_model_credential(`custom_model`) :param session: the database session :return: """ # Find all load balancing configs that use this credential_id stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == credential_source, ) load_balancing_configs = session.execute(stmt).scalars().all() if not load_balancing_configs: return # Update each load balancing config with the new credentials for lb_config in load_balancing_configs: # Update the encrypted_config with the new credentials lb_config.encrypted_config = credential_record.encrypted_config lb_config.name = credential_record.credential_name lb_config.updated_at = naive_utc_now() # Clear cache for this load balancing config lb_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=lb_config.id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() session.commit() def delete_provider_credential(self, credential_id: str): """ Delete a saved provider credential (by credential_id). :param credential_id: credential id :return: """ with Session(db.engine) as session: stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ) # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") # Check if this credential is used in load balancing configs lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "provider", ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: for lb_config in lb_configs_using_credential: lb_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=lb_config.id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() session.delete(lb_config) # Check if this is the currently active credential provider_record = self._get_provider_record(session) # Check available credentials count BEFORE deleting # if this is the last credential, we need to delete the provider record count_stmt = select(func.count(ProviderCredential.id)).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ) available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) if provider_record and available_credentials_count <= 1: # If all credentials are deleted, delete the provider record session.delete(provider_record) provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) elif provider_record and provider_record.credential_id == credential_id: provider_record.credential_id = None provider_record.updated_at = naive_utc_now() provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session) session.commit() except Exception: session.rollback() raise def switch_active_provider_credential(self, credential_id: str): """ Switch active provider credential (copy the selected one into current active snapshot). :param credential_id: credential id :return: """ with Session(db.engine) as session: stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") provider_record = self._get_provider_record(session) if not provider_record: raise ValueError("Provider record not found.") try: provider_record.credential_id = credential_record.id provider_record.updated_at = naive_utc_now() session.commit() provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) except Exception: session.rollback() raise def _get_custom_model_record( self, model_type: ModelType, model: str, session: Session, ) -> ProviderModel | None: """ Get custom model credentials. """ # get provider model model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) stmt = select(ProviderModel).where( ProviderModel.tenant_id == self.tenant_id, ProviderModel.provider_name.in_(provider_names), ProviderModel.model_name == model, ProviderModel.model_type == model_type.to_origin_model_type(), ) return session.execute(stmt).scalar_one_or_none() def _get_specific_custom_model_credential( self, model_type: ModelType, model: str, credential_id: str ) -> dict | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID :return: """ model_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] ) with Session(db.engine) as session: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record or not credential_record.encrypted_config: raise ValueError(f"Credential with id {credential_id} not found.") try: credentials = json.loads(credential_record.encrypted_config) except JSONDecodeError: credentials = {} # Decrypt secret variables for key in model_credential_secret_variables: if key in credentials and credentials[key] is not None: try: credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) except Exception: pass current_credential_id = credential_record.id current_credential_name = credential_record.credential_name credentials = self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [], ) return { "current_credential_id": current_credential_id, "current_credential_name": current_credential_name, "credentials": credentials, } def _check_custom_model_credential_name_exists( self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None ) -> bool: """ not allowed same name when create or update a credential """ stmt = select(ProviderModelCredential).where( ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.credential_name == credential_name, ) if exclude_id: stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: """ Get custom model credentials. :param model_type: model type :param model: model name :return: """ # If credential_id is provided, return the specific credential if credential_id: return self._get_specific_custom_model_credential( model_type=model_type, model=model, credential_id=credential_id ) for model_configuration in self.custom_configuration.models: if ( model_configuration.model_type == model_type and model_configuration.model == model and model_configuration.credentials ): current_credential_id = model_configuration.current_credential_id current_credential_name = model_configuration.current_credential_name credentials = self.obfuscated_credentials( credentials=model_configuration.credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [], ) return { "current_credential_id": current_credential_id, "current_credential_name": current_credential_name, "credentials": credentials, } return None def validate_custom_model_credentials( self, model_type: ModelType, model: str, credentials: dict, credential_id: str = "", session: Session | None = None, ): """ Validate custom model credentials. :param model_type: model type :param model: model name :param credentials: model credentials dict :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ def _validate(s: Session): # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas if self.provider.model_credential_schema else [] ) if credential_id: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = s.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config else {} ) except JSONDecodeError: original_credentials = {} # decrypt credentials for key, value in credentials.items(): if key in provider_credential_secret_variables: # if send [__HIDDEN__] in secret input, it will be same as original value if value == HIDDEN_VALUE and key in original_credentials: credentials[key] = encrypter.decrypt_token( tenant_id=self.tenant_id, token=original_credentials[key] ) model_provider_factory = ModelProviderFactory(self.tenant_id) validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) for key, value in validated_credentials.items(): if key in provider_credential_secret_variables: validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) return validated_credentials if session: return _validate(session) else: with Session(db.engine) as new_session: return _validate(new_session) def create_custom_model_credential( self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None ) -> None: """ Create a custom model credential. :param model_type: model type :param model: model name :param credentials: model credentials dict :return: """ with Session(db.engine) as session: if credential_name: if self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, session=session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( model=model, model_type=model_type, session=session ) # validate custom model config credentials = self.validate_custom_model_credentials( model_type=model_type, model=model, credentials=credentials, session=session ) provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: credential = ProviderModelCredential( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_name=model, model_type=model_type.to_origin_model_type(), encrypted_config=json.dumps(credentials), credential_name=credential_name, ) session.add(credential) session.flush() # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_name=model, model_type=model_type.to_origin_model_type(), credential_id=credential.id, is_valid=True, ) session.add(provider_model_record) session.commit() provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() except Exception: session.rollback() raise def update_custom_model_credential( self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str ) -> None: """ Update a custom model credential. :param model_type: model type :param model: model name :param credentials: model credentials dict :param credential_name: credential name :param credential_id: credential id :return: """ with Session(db.engine) as session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, session=session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") # validate custom model config credentials = self.validate_custom_model_credentials( model_type=model_type, model=model, credentials=credentials, credential_id=credential_id, session=session, ) provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: credential_record.credential_name = credential_name session.commit() if provider_model_record and provider_model_record.credential_id == credential_id: provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, credential_source="custom_model", session=session, ) except Exception: session.rollback() raise def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): """ Delete a saved provider credential (by credential_id). :param credential_id: credential id :return: """ with Session(db.engine) as session: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "custom_model", ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: for lb_config in lb_configs_using_credential: lb_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=lb_config.id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) lb_credentials_cache.delete() session.delete(lb_config) # Check if this is the currently active credential provider_model_record = self._get_custom_model_record(model_type, model, session=session) # Check available credentials count BEFORE deleting # if this is the last credential, we need to delete the custom model record count_stmt = select(func.count(ProviderModelCredential.id)).where( ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) if provider_model_record and available_credentials_count <= 1: # If all credentials are deleted, delete the custom model record session.delete(provider_model_record) elif provider_model_record and provider_model_record.credential_id == credential_id: provider_model_record.credential_id = None provider_model_record.updated_at = naive_utc_now() provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() session.commit() except Exception: session.rollback() raise def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str): """ if model list exist this custom model, switch the custom model credential. if model list not exist this custom model, use the credential to add a new custom model record. :param model_type: model type :param model: model name :param credential_id: credential id :return: """ with Session(db.engine) as session: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") # validate custom model config provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) if not provider_model_record: # create provider model record provider_model_record = ProviderModel( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_name=model, model_type=model_type.to_origin_model_type(), is_valid=True, credential_id=credential_id, ) else: if provider_model_record.credential_id == credential_record.id: raise ValueError("Can't add same credential") provider_model_record.credential_id = credential_record.id provider_model_record.updated_at = naive_utc_now() # clear cache provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() session.add(provider_model_record) session.commit() def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str): """ switch the custom model credential. :param model_type: model type :param model: model name :param credential_id: credential id :return: """ with Session(db.engine) as session: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) if not provider_model_record: raise ValueError("The custom model record not found.") provider_model_record.credential_id = credential_record.id provider_model_record.updated_at = naive_utc_now() session.add(provider_model_record) session.commit() # clear cache provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() def delete_custom_model(self, model_type: ModelType, model: str): """ Delete custom model. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: # get provider model provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) # delete provider model if provider_model_record: session.delete(provider_model_record) session.commit() provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() def _get_provider_model_setting( self, model_type: ModelType, model: str, session: Session ) -> ProviderModelSetting | None: """ Get provider model setting. """ stmt = select(ProviderModelSetting).where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(self._get_provider_names()), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) return session.execute(stmt).scalars().first() def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Enable model. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.enabled = True model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, enabled=True, ) session.add(model_setting) session.commit() return model_setting def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Disable model. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.enabled = False model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, enabled=False, ) session.add(model_setting) session.commit() return model_setting def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None: """ Get provider model setting. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: return self._get_provider_model_setting(model_type=model_type, model=model, session=session) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Enable model load balancing. :param model_type: model type :param model: model name :return: """ model_provider_id = ModelProviderID(self.provider.provider) provider_names = [self.provider.provider] if model_provider_id.is_langgenius(): provider_names.append(model_provider_id.provider_name) with Session(db.engine) as session: stmt = select(func.count(LoadBalancingModelConfig.id)).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) load_balancing_config_count = session.execute(stmt).scalar() or 0 if load_balancing_config_count <= 1: raise ValueError("Model load balancing configuration must be more than 1.") model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.load_balancing_enabled = True model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, load_balancing_enabled=True, ) session.add(model_setting) session.commit() return model_setting def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ Disable model load balancing. :param model_type: model type :param model: model name :return: """ with Session(db.engine) as session: model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session) if model_setting: model_setting.load_balancing_enabled = False model_setting.updated_at = naive_utc_now() else: model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, load_balancing_enabled=False, ) session.add(model_setting) session.commit() return model_setting def get_model_type_instance(self, model_type: ModelType) -> AIModel: """ Get current model type instance. :param model_type: model type :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: """ Get model schema """ model_provider_factory = ModelProviderFactory(self.tenant_id) return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None): """ Switch preferred provider type. :param provider_type: :return: """ if provider_type == self.preferred_provider_type: return if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: return def _switch(s: Session): stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()), ) preferred_model_provider = s.execute(stmt).scalars().first() if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value else: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, preferred_provider_type=provider_type.value, ) s.add(preferred_model_provider) s.commit() if session: return _switch(session) else: with Session(db.engine) as session: return _switch(session) def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. :param credential_form_schemas: :return: """ secret_input_form_variables = [] for credential_form_schema in credential_form_schemas: if credential_form_schema.type == FormType.SECRET_INPUT: secret_input_form_variables.append(credential_form_schema.variable) return secret_input_form_variables def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): """ Obfuscated credentials. :param credentials: credentials :param credential_form_schemas: credential form schemas :return: """ # Get provider credential secret variables credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() for key, value in copy_credentials.items(): if key in credential_secret_variables: copy_credentials[key] = encrypter.obfuscated_token(value) return copy_credentials def get_provider_model( self, model_type: ModelType, model: str, only_active: bool = False ) -> ModelWithProviderEntity | None: """ Get provider model. :param model_type: model type :param model: model name :param only_active: return active model only :return: """ provider_models = self.get_provider_models(model_type, only_active, model) for provider_model in provider_models: if provider_model.model == model: return provider_model return None def get_provider_models( self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type :param only_active: only active models :param model: model name :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] if model_type: model_types.append(model_type) else: model_types = list(provider_schema.supported_model_types) # Group model settings by model type and model model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) for model_setting in self.model_settings: model_setting_map[model_setting.model_type][model_setting.model] = model_setting if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map, model=model, ) if only_active: provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] # resort provider_models # Optimize sorting logic: first sort by provider.position order, then by model_type.value # Get the position list for model types (retrieve only once for better performance) model_type_positions = {} if hasattr(self.provider, "position") and self.provider.position: model_type_positions = self.provider.position def get_sort_key(model: ModelWithProviderEntity): # Get the position list for the current model type positions = model_type_positions.get(model.model_type.value, []) # If the model name is in the position list, use its index for sorting # Otherwise use a large value (list length) to place undefined models at the end position_index = positions.index(model.model) if model.model in positions else len(positions) # Return composite sort key: (model_type value, model position index) return (model.model_type.value, position_index) # Deduplicate provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values()) # Sort using the composite sort key return sorted(provider_models, key=get_sort_key) def _get_system_provider_models( self, model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], ) -> list[ModelWithProviderEntity]: """ Get system provider models. :param model_types: model types :param provider_schema: provider schema :param model_setting_map: model setting map :return: """ provider_models = [] for model_type in model_types: for m in provider_schema.models: if m.model_type != model_type: continue status = ModelStatus.ACTIVE if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED provider_models.append( ModelWithProviderEntity( model=m.model, label=m.label, model_type=m.model_type, features=m.features, fetch_from=m.fetch_from, model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, ) ) if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in provider_schema.configurate_methods: original_provider_configurate_methods[self.provider.provider].append(configurate_method) should_use_custom_model = False if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: should_use_custom_model = True for quota_configuration in self.system_configuration.quota_configurations: if self.system_configuration.current_quota_type != quota_configuration.quota_type: continue restrict_models = quota_configuration.restrict_models if len(restrict_models) == 0: break if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ ConfigurateMethod.CUSTOMIZABLE_MODEL ]: # only customizable model for restrict_model in restrict_models: copy_credentials = ( self.system_configuration.credentials.copy() if self.system_configuration.credentials else {} ) if restrict_model.base_model_name: copy_credentials["base_model_name"] = restrict_model.base_model_name try: custom_model_schema = self.get_model_schema( model_type=restrict_model.model_type, model=restrict_model.model, credentials=copy_credentials, ) except Exception as ex: logger.warning("get custom model schema failed, %s", ex) continue if not custom_model_schema: continue if custom_model_schema.model_type not in model_types: continue status = ModelStatus.ACTIVE if ( custom_model_schema.model_type in model_setting_map and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, label=custom_model_schema.label, model_type=custom_model_schema.model_type, features=custom_model_schema.features, fetch_from=FetchFrom.PREDEFINED_MODEL, model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, ) ) # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] for model in provider_models: if model.model_type == ModelType.LLM and model.model not in restrict_model_names: model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: model.status = ModelStatus.QUOTA_EXCEEDED return provider_models def _get_custom_provider_models( self, model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], model: str | None = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. :param model_types: model types :param provider_schema: provider schema :param model_setting_map: model setting map :return: """ provider_models = [] credentials = None if self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials for model_type in model_types: if model_type not in self.provider.supported_model_types: continue for m in provider_schema.models: if m.model_type != model_type: continue status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE load_balancing_enabled = False has_invalid_load_balancing_configs = False if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs if config.credential_source_type != "custom_model" ] load_balancing_enabled = model_setting.load_balancing_enabled # when the user enable load_balancing but available configs are less than 2 display warning has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2 provider_models.append( ModelWithProviderEntity( model=m.model, label=m.label, model_type=m.model_type, features=m.features, fetch_from=m.fetch_from, model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) # custom models for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue if model_configuration.unadded_to_model_list: continue if model and model != model_configuration.model: continue try: custom_model_schema = self.get_model_schema( model_type=model_configuration.model_type, model=model_configuration.model, credentials=model_configuration.credentials, ) except Exception as ex: logger.warning("get custom model schema failed, %s", ex) continue if not custom_model_schema: continue status = ModelStatus.ACTIVE load_balancing_enabled = False has_invalid_load_balancing_configs = False if ( custom_model_schema.model_type in model_setting_map and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs if config.credential_source_type != "provider" ] load_balancing_enabled = model_setting.load_balancing_enabled # when the user enable load_balancing but available configs are less than 2 display warning has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2 if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: status = ModelStatus.CREDENTIAL_REMOVED provider_models.append( ModelWithProviderEntity( model=custom_model_schema.model, label=custom_model_schema.label, model_type=custom_model_schema.model_type, features=custom_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, load_balancing_enabled=load_balancing_enabled, has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) return provider_models class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ tenant_id: str configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict) def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) def get_models( self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False ) -> list[ModelWithProviderEntity]: """ Get available models. If preferred provider type is `system`: Get the current **system mode** if provider supported, if all system modes are not available (no quota), it is considered to be the **custom credential mode**. If there is no model configured in custom mode, it is treated as no_configure. system > custom > no_configure If preferred provider type is `custom`: If custom credentials are configured, it is treated as custom mode. Otherwise, get the current **system mode** if supported, If all system modes are not available (no quota), it is treated as no_configure. custom > system > no_configure If real mode is `system`, use system credentials to get models, paid quotas > provider free quotas > system free quotas include pre-defined models (exclude GPT-4, status marked as `no_permission`). If real mode is `custom`, use workspace custom credentials to get models, include pre-defined models, custom models(manual append). If real mode is `no_configure`, only return pre-defined models from `model runtime`. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`) model status marked as `active` is available. :param provider: provider name :param model_type: model type :param only_active: only active models :return: """ all_models = [] for provider_configuration in self.values(): if provider and provider_configuration.provider.provider != provider: continue all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) return all_models def to_list(self) -> list[ProviderConfiguration]: """ Convert to list. :return: """ return list(self.values()) def __getitem__(self, key): if "/" not in key: key = str(ModelProviderID(key)) return self.configurations[key] def __setitem__(self, key, value): self.configurations[key] = value def __contains__(self, key): if "/" not in key: key = str(ModelProviderID(key)) return key in self.configurations def __iter__(self): # Return an iterator of (key, value) tuples to match BaseModel's __iter__ yield from self.configurations.items() def values(self) -> Iterator[ProviderConfiguration]: return iter(self.configurations.values()) def get(self, key, default=None) -> ProviderConfiguration | None: if "/" not in key: key = str(ModelProviderID(key)) return self.configurations.get(key, default) class ProviderModelBundle(BaseModel): """ Provider model bundle. """ configuration: ProviderConfiguration model_type_instance: AIModel # pydantic configs model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())