package com.iailab.module.ai.service.write;
|
|
import cn.hutool.core.collection.CollUtil;
|
import cn.hutool.core.util.ObjUtil;
|
import cn.hutool.core.util.StrUtil;
|
import com.iailab.framework.ai.core.enums.AiModelTypeEnum;
|
import com.iailab.framework.ai.core.enums.AiPlatformEnum;
|
import com.iailab.framework.ai.core.util.AiUtils;
|
import com.iailab.framework.common.pojo.CommonResult;
|
import com.iailab.framework.common.pojo.PageResult;
|
import com.iailab.framework.common.util.object.BeanUtils;
|
import com.iailab.framework.tenant.core.util.TenantUtils;
|
import com.iailab.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
import com.iailab.module.ai.controller.admin.write.vo.AiWritePageReqVO;
|
import com.iailab.module.ai.dal.dataobject.model.AiChatRoleDO;
|
import com.iailab.module.ai.dal.dataobject.model.AiModelDO;
|
import com.iailab.module.ai.dal.dataobject.write.AiWriteDO;
|
import com.iailab.module.ai.dal.mysql.write.AiWriteMapper;
|
import com.iailab.module.ai.enums.AiChatRoleEnum;
|
import com.iailab.module.ai.enums.DictTypeConstants;
|
import com.iailab.module.ai.enums.ErrorCodeConstants;
|
import com.iailab.module.ai.enums.write.AiWriteTypeEnum;
|
import com.iailab.module.ai.service.model.AiChatRoleService;
|
import com.iailab.module.ai.service.model.AiModelService;
|
import com.iailab.module.system.api.dict.DictDataApi;
|
import jakarta.annotation.Resource;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.ai.chat.messages.Message;
|
import org.springframework.ai.chat.messages.SystemMessage;
|
import org.springframework.ai.chat.messages.UserMessage;
|
import org.springframework.ai.chat.model.ChatResponse;
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.stereotype.Service;
|
import reactor.core.publisher.Flux;
|
|
import java.util.ArrayList;
|
import java.util.List;
|
import java.util.Objects;
|
|
import static com.iailab.framework.common.exception.util.ServiceExceptionUtil.exception;
|
import static com.iailab.framework.common.pojo.CommonResult.error;
|
import static com.iailab.framework.common.pojo.CommonResult.success;
|
import static com.iailab.module.ai.enums.ErrorCodeConstants.*;
|
|
/**
|
* AI 写作 Service 实现类
|
*
|
* @author xiaoxin
|
*/
|
@Service
|
@Slf4j
|
public class AiWriteServiceImpl implements AiWriteService {
|
|
@Resource
|
private AiModelService modalService;
|
@Resource
|
private AiChatRoleService chatRoleService;
|
|
@Resource
|
private AiWriteMapper writeMapper;
|
|
@Resource
|
private DictDataApi dictDataApi;
|
|
@Override
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
// 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
|
AiChatRoleDO writeRole = CollUtil.getFirst(
|
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
// 1.1 获取写作执行模型
|
AiModelDO model = getModel(writeRole);
|
// 1.2 获取角色设定消息
|
String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage())
|
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
|
// 1.3 校验平台
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
StreamingChatModel chatModel = modalService.getChatModel(model.getId());
|
|
// 2. 插入写作信息
|
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)
|
.setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
|
writeMapper.insert(writeDO);
|
|
// 3.1 构建 Prompt,并进行调用
|
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
// 3.2 流式返回
|
StringBuffer contentBuffer = new StringBuffer();
|
return streamResponse.map(chunk -> {
|
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
|
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
|
contentBuffer.append(newContent);
|
// 响应结果
|
return success(newContent);
|
}).doOnComplete(() -> {
|
// 忽略租户,因为 Flux 异步无法透传租户
|
TenantUtils.executeIgnore(() ->
|
writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setGeneratedContent(contentBuffer.toString())));
|
}).doOnError(throwable -> {
|
log.error("[generateWriteContent][generateReqVO({}) 发生异常]", generateReqVO, throwable);
|
// 忽略租户,因为 Flux 异步无法透传租户
|
TenantUtils.executeIgnore(() ->
|
writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setErrorMessage(throwable.getMessage())));
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
}
|
|
private AiModelDO getModel(AiChatRoleDO writeRole) {
|
AiModelDO model = null;
|
if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
|
model = modalService.getModel(writeRole.getModelId());
|
}
|
if (model == null) {
|
model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
|
}
|
// 校验模型存在、且合法
|
if (model == null) {
|
throw exception(MODEL_NOT_EXISTS);
|
}
|
if (ObjUtil.notEqual(model.getType(), AiModelTypeEnum.CHAT.getType())) {
|
throw exception(MODEL_USE_TYPE_ERROR);
|
}
|
return model;
|
}
|
|
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
|
// 1. 构建 message 列表
|
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
// 2. 构建 options 对象
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
return new Prompt(chatMessages, options);
|
}
|
|
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
|
List<Message> chatMessages = new ArrayList<>();
|
if (StrUtil.isNotBlank(systemMessage)) {
|
// 1.1 角色设定
|
chatMessages.add(new SystemMessage(systemMessage));
|
}
|
// 1.2 用户输入
|
chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
|
return chatMessages;
|
}
|
|
private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
|
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
|
// 格式化 prompt
|
String prompt = generateReqVO.getPrompt();
|
if (Objects.equals(generateReqVO.getType(), AiWriteTypeEnum.WRITING.getType())) {
|
return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length);
|
} else {
|
return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
|
}
|
}
|
|
@Override
|
public void deleteWrite(Long id) {
|
// 校验存在
|
validateWriteExists(id);
|
// 删除
|
writeMapper.deleteById(id);
|
}
|
|
private void validateWriteExists(Long id) {
|
if (writeMapper.selectById(id) == null) {
|
throw exception(WRITE_NOT_EXISTS);
|
}
|
}
|
|
@Override
|
public PageResult<AiWriteDO> getWritePage(AiWritePageReqVO pageReqVO) {
|
return writeMapper.selectPage(pageReqVO);
|
}
|
|
}
|