package com.iailab.module.ai.service.model;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.iailab.framework.ai.core.enums.AiPlatformEnum;
|
import com.iailab.framework.ai.core.factory.AiModelFactory;
|
import com.iailab.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import com.iailab.framework.ai.core.model.suno.api.SunoApi;
|
import com.iailab.framework.common.enums.CommonStatusEnum;
|
import com.iailab.framework.common.pojo.PageResult;
|
import com.iailab.framework.common.util.object.BeanUtils;
|
import com.iailab.framework.mybatis.core.query.LambdaQueryWrapperX;
|
import com.iailab.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
|
import com.iailab.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
|
import com.iailab.module.ai.dal.dataobject.model.AiApiKeyDO;
|
import com.iailab.module.ai.dal.dataobject.model.AiModelDO;
|
import com.iailab.module.ai.dal.mysql.model.AiChatMapper;
|
import jakarta.annotation.Resource;
|
import org.springframework.ai.chat.model.ChatModel;
|
import org.springframework.ai.embedding.EmbeddingModel;
|
import org.springframework.ai.image.ImageModel;
|
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
import org.springframework.stereotype.Service;
|
import org.springframework.validation.annotation.Validated;
|
|
import java.util.List;
|
import java.util.Map;
|
|
import static com.iailab.framework.common.exception.util.ServiceExceptionUtil.exception;
|
import static com.iailab.module.ai.enums.ErrorCodeConstants.*;
|
|
/**
|
* AI 模型 Service 实现类
|
*
|
* @author fansili
|
*/
|
@Service
|
@Validated
|
public class AiModelServiceImpl implements AiModelService {
|
|
@Resource
|
private AiApiKeyService apiKeyService;
|
|
@Resource
|
private AiChatMapper modelMapper;
|
|
@Resource
|
private AiModelFactory modelFactory;
|
|
@Override
|
public Long createModel(AiModelSaveReqVO createReqVO) {
|
// 1. 校验
|
AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
|
apiKeyService.validateApiKey(createReqVO.getKeyId());
|
|
//模型名称不能重复
|
List<AiModelDO> aiModelDOS = modelMapper.selectList(new LambdaQueryWrapperX<AiModelDO>().eq(AiModelDO::getName, createReqVO.getName()));
|
if (aiModelDOS.size() > 0) {
|
throw exception(MODEL_NAME_EXISTS);
|
}
|
|
// 2. 插入
|
AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
|
modelMapper.insert(model);
|
return model.getId();
|
}
|
|
@Override
|
public void updateModel(AiModelSaveReqVO updateReqVO) {
|
// 1. 校验
|
validateModelExists(updateReqVO.getId());
|
AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
|
apiKeyService.validateApiKey(updateReqVO.getKeyId());
|
|
//模型名称不能重复
|
List<AiModelDO> aiModelDOS = modelMapper.selectList(new LambdaQueryWrapperX<AiModelDO>()
|
.eq(AiModelDO::getName, updateReqVO.getName())
|
.ne(AiModelDO::getId, updateReqVO.getId()));
|
if (aiModelDOS.size() > 0) {
|
throw exception(MODEL_NAME_EXISTS);
|
}
|
|
// 2. 更新
|
AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
|
modelMapper.updateById(updateObj);
|
}
|
|
@Override
|
public void deleteModel(Long id) {
|
// 校验存在
|
validateModelExists(id);
|
// 删除
|
modelMapper.deleteById(id);
|
}
|
|
private AiModelDO validateModelExists(Long id) {
|
AiModelDO model = modelMapper.selectById(id);
|
if (modelMapper.selectById(id) == null) {
|
throw exception(MODEL_NOT_EXISTS);
|
}
|
return model;
|
}
|
|
@Override
|
public AiModelDO getModel(Long id) {
|
return modelMapper.selectById(id);
|
}
|
|
@Override
|
public AiModelDO getModelByName(String name) {
|
return modelMapper.selectOne(new LambdaQueryWrapper<AiModelDO>().eq(AiModelDO::getName, name));
|
}
|
|
@Override
|
public AiModelDO getRequiredDefaultModel(Integer type) {
|
AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
|
if (model == null) {
|
throw exception(MODEL_DEFAULT_NOT_EXISTS);
|
}
|
return model;
|
}
|
|
@Override
|
public PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO) {
|
return modelMapper.selectPage(pageReqVO);
|
}
|
|
@Override
|
public AiModelDO validateModel(Long id) {
|
AiModelDO model = validateModelExists(id);
|
if (CommonStatusEnum.isDisable(model.getStatus())) {
|
throw exception(MODEL_DISABLE);
|
}
|
return model;
|
}
|
|
@Override
|
public List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type, String platform) {
|
return modelMapper.selectListByStatusAndType(status, type, platform);
|
}
|
|
// ========== 与 Spring AI 集成 ==========
|
|
@Override
|
public ChatModel getChatModel(Long id) {
|
AiModelDO model = validateModel(id);
|
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
}
|
|
@Override
|
public ImageModel getImageModel(Long id) {
|
AiModelDO model = validateModel(id);
|
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
|
}
|
|
@Override
|
public MidjourneyApi getMidjourneyApi(Long id) {
|
AiModelDO model = validateModel(id);
|
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
|
}
|
|
@Override
|
public SunoApi getSunoApi() {
|
AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
|
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
|
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
|
}
|
|
@Override
|
public VectorStore getOrCreateVectorStore(Long id, Map<String, Class<?>> metadataFields) {
|
// 获取模型 + 密钥
|
AiModelDO model = validateModel(id);
|
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
|
|
// 创建或获取 EmbeddingModel 对象
|
EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
|
platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
|
|
// 创建或获取 VectorStore 对象
|
return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel, metadataFields);
|
// return modelFactory.getOrCreateVectorStore(QdrantVectorStore.class, embeddingModel, metadataFields);
|
// return modelFactory.getOrCreateVectorStore(RedisVectorStore.class, embeddingModel, metadataFields);
|
// return modelFactory.getOrCreateVectorStore(MilvusVectorStore.class, embeddingModel, metadataFields);
|
}
|
|
}
|