工业互联网平台2.0版本后端代码
houzhongjian
2025-05-29 41499fd3c28216c1526a72b10fa98eb8ffee78cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
package com.iailab.module.ai.service.model;
 
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.iailab.framework.ai.core.enums.AiPlatformEnum;
import com.iailab.framework.ai.core.factory.AiModelFactory;
import com.iailab.framework.ai.core.model.midjourney.api.MidjourneyApi;
import com.iailab.framework.ai.core.model.suno.api.SunoApi;
import com.iailab.framework.common.enums.CommonStatusEnum;
import com.iailab.framework.common.pojo.PageResult;
import com.iailab.framework.common.util.object.BeanUtils;
import com.iailab.framework.mybatis.core.query.LambdaQueryWrapperX;
import com.iailab.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import com.iailab.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import com.iailab.module.ai.dal.dataobject.model.AiApiKeyDO;
import com.iailab.module.ai.dal.dataobject.model.AiModelDO;
import com.iailab.module.ai.dal.mysql.model.AiChatMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
 
import java.util.List;
import java.util.Map;
 
import static com.iailab.framework.common.exception.util.ServiceExceptionUtil.exception;
import static com.iailab.module.ai.enums.ErrorCodeConstants.*;
 
/**
 * AI 模型 Service 实现类
 *
 * @author fansili
 */
@Service
@Validated
public class AiModelServiceImpl implements AiModelService {
 
    @Resource
    private AiApiKeyService apiKeyService;
 
    @Resource
    private AiChatMapper modelMapper;
 
    @Resource
    private AiModelFactory modelFactory;
 
    @Override
    public Long createModel(AiModelSaveReqVO createReqVO) {
        // 1. 校验
        AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
        apiKeyService.validateApiKey(createReqVO.getKeyId());
 
        //模型名称不能重复
        List<AiModelDO> aiModelDOS = modelMapper.selectList(new LambdaQueryWrapperX<AiModelDO>().eq(AiModelDO::getName, createReqVO.getName()));
        if (aiModelDOS.size() > 0) {
            throw exception(MODEL_NAME_EXISTS);
        }
 
        // 2. 插入
        AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
        modelMapper.insert(model);
        return model.getId();
    }
 
    @Override
    public void updateModel(AiModelSaveReqVO updateReqVO) {
        // 1. 校验
        validateModelExists(updateReqVO.getId());
        AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
        apiKeyService.validateApiKey(updateReqVO.getKeyId());
 
        //模型名称不能重复
        List<AiModelDO> aiModelDOS = modelMapper.selectList(new LambdaQueryWrapperX<AiModelDO>()
                .eq(AiModelDO::getName, updateReqVO.getName())
                .ne(AiModelDO::getId, updateReqVO.getId()));
        if (aiModelDOS.size() > 0) {
            throw exception(MODEL_NAME_EXISTS);
        }
 
        // 2. 更新
        AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
        modelMapper.updateById(updateObj);
    }
 
    @Override
    public void deleteModel(Long id) {
        // 校验存在
        validateModelExists(id);
        // 删除
        modelMapper.deleteById(id);
    }
 
    private AiModelDO validateModelExists(Long id) {
        AiModelDO model = modelMapper.selectById(id);
        if (modelMapper.selectById(id) == null) {
            throw exception(MODEL_NOT_EXISTS);
        }
        return model;
    }
 
    @Override
    public AiModelDO getModel(Long id) {
        return modelMapper.selectById(id);
    }
 
    @Override
    public AiModelDO getModelByName(String name) {
        return modelMapper.selectOne(new LambdaQueryWrapper<AiModelDO>().eq(AiModelDO::getName, name));
    }
 
    @Override
    public AiModelDO getRequiredDefaultModel(Integer type) {
        AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
        if (model == null) {
            throw exception(MODEL_DEFAULT_NOT_EXISTS);
        }
        return model;
    }
 
    @Override
    public PageResult<AiModelDO> getModelPage(AiModelPageReqVO pageReqVO) {
        return modelMapper.selectPage(pageReqVO);
    }
 
    @Override
    public AiModelDO validateModel(Long id) {
        AiModelDO model = validateModelExists(id);
        if (CommonStatusEnum.isDisable(model.getStatus())) {
            throw exception(MODEL_DISABLE);
        }
        return model;
    }
 
    @Override
    public List<AiModelDO> getModelListByStatusAndType(Integer status, Integer type, String platform) {
        return modelMapper.selectListByStatusAndType(status, type, platform);
    }
 
    // ========== 与 Spring AI 集成 ==========
 
    @Override
    public ChatModel getChatModel(Long id) {
        AiModelDO model = validateModel(id);
        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
        return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
    }
 
    @Override
    public ImageModel getImageModel(Long id) {
        AiModelDO model = validateModel(id);
        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
        return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
    }
 
    @Override
    public MidjourneyApi getMidjourneyApi(Long id) {
        AiModelDO model = validateModel(id);
        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
        return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
    }
 
    @Override
    public SunoApi getSunoApi() {
        AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
                AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
        return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
    }
 
    @Override
    public VectorStore getOrCreateVectorStore(Long id, Map<String, Class<?>> metadataFields) {
        // 获取模型 + 密钥
        AiModelDO model = validateModel(id);
        AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
 
        // 创建或获取 EmbeddingModel 对象
        EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
                platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
 
        // 创建或获取 VectorStore 对象
         return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel, metadataFields);
//         return modelFactory.getOrCreateVectorStore(QdrantVectorStore.class, embeddingModel, metadataFields);
//         return modelFactory.getOrCreateVectorStore(RedisVectorStore.class, embeddingModel, metadataFields);
//         return modelFactory.getOrCreateVectorStore(MilvusVectorStore.class, embeddingModel, metadataFields);
    }
 
}