package com.iailab.module.ai.service.write; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import com.iailab.framework.ai.core.enums.AiModelTypeEnum; import com.iailab.framework.ai.core.enums.AiPlatformEnum; import com.iailab.framework.ai.core.util.AiUtils; import com.iailab.framework.common.pojo.CommonResult; import com.iailab.framework.common.pojo.PageResult; import com.iailab.framework.common.util.object.BeanUtils; import com.iailab.framework.tenant.core.util.TenantUtils; import com.iailab.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import com.iailab.module.ai.controller.admin.write.vo.AiWritePageReqVO; import com.iailab.module.ai.dal.dataobject.model.AiChatRoleDO; import com.iailab.module.ai.dal.dataobject.model.AiModelDO; import com.iailab.module.ai.dal.dataobject.write.AiWriteDO; import com.iailab.module.ai.dal.mysql.write.AiWriteMapper; import com.iailab.module.ai.enums.AiChatRoleEnum; import com.iailab.module.ai.enums.DictTypeConstants; import com.iailab.module.ai.enums.ErrorCodeConstants; import com.iailab.module.ai.enums.write.AiWriteTypeEnum; import com.iailab.module.ai.service.model.AiChatRoleService; import com.iailab.module.ai.service.model.AiModelService; import com.iailab.module.system.api.dict.DictDataApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; import java.util.Objects; import static com.iailab.framework.common.exception.util.ServiceExceptionUtil.exception; import static com.iailab.framework.common.pojo.CommonResult.error; import static com.iailab.framework.common.pojo.CommonResult.success; import static com.iailab.module.ai.enums.ErrorCodeConstants.*; /** * AI 写作 Service 实现类 * * @author xiaoxin */ @Service @Slf4j public class AiWriteServiceImpl implements AiWriteService { @Resource private AiModelService modalService; @Resource private AiChatRoleService chatRoleService; @Resource private AiWriteMapper writeMapper; @Resource private DictDataApi dictDataApi; @Override public Flux> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { // 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型 AiChatRoleDO writeRole = CollUtil.getFirst( chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName())); // 1.1 获取写作执行模型 AiModelDO model = getModel(writeRole); // 1.2 获取角色设定消息 String systemMessage = Objects.nonNull(writeRole) && StrUtil.isNotBlank(writeRole.getSystemMessage()) ? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage(); // 1.3 校验平台 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); StreamingChatModel chatModel = modalService.getChatModel(model.getId()); // 2. 插入写作信息 AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId) .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel())); writeMapper.insert(writeDO); // 3.1 构建 Prompt,并进行调用 Prompt prompt = buildPrompt(generateReqVO, model, systemMessage); Flux streamResponse = chatModel.stream(prompt); // 3.2 流式返回 StringBuffer contentBuffer = new StringBuffer(); return streamResponse.map(chunk -> { String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null; newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况 contentBuffer.append(newContent); // 响应结果 return success(newContent); }).doOnComplete(() -> { // 忽略租户,因为 Flux 异步无法透传租户 TenantUtils.executeIgnore(() -> writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setGeneratedContent(contentBuffer.toString()))); }).doOnError(throwable -> { log.error("[generateWriteContent][generateReqVO({}) 发生异常]", generateReqVO, throwable); // 忽略租户,因为 Flux 异步无法透传租户 TenantUtils.executeIgnore(() -> writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setErrorMessage(throwable.getMessage()))); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); } private AiModelDO getModel(AiChatRoleDO writeRole) { AiModelDO model = null; if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) { model = modalService.getModel(writeRole.getModelId()); } if (model == null) { model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType()); } // 校验模型存在、且合法 if (model == null) { throw exception(MODEL_NOT_EXISTS); } if (ObjUtil.notEqual(model.getType(), AiModelTypeEnum.CHAT.getType())) { throw exception(MODEL_USE_TYPE_ERROR); } return model; } private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) { // 1. 构建 message 列表 List chatMessages = buildMessages(generateReqVO, systemMessage); // 2. 构建 options 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); return new Prompt(chatMessages, options); } private List buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) { List chatMessages = new ArrayList<>(); if (StrUtil.isNotBlank(systemMessage)) { // 1.1 角色设定 chatMessages.add(new SystemMessage(systemMessage)); } // 1.2 用户输入 chatMessages.add(new UserMessage(buildUserMessage(generateReqVO))); return chatMessages; } private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) { String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength()); // 格式化 prompt String prompt = generateReqVO.getPrompt(); if (Objects.equals(generateReqVO.getType(), AiWriteTypeEnum.WRITING.getType())) { return StrUtil.format(AiWriteTypeEnum.WRITING.getPrompt(), prompt, format, tone, language, length); } else { return StrUtil.format(AiWriteTypeEnum.REPLY.getPrompt(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length); } } @Override public void deleteWrite(Long id) { // 校验存在 validateWriteExists(id); // 删除 writeMapper.deleteById(id); } private void validateWriteExists(Long id) { if (writeMapper.selectById(id) == null) { throw exception(WRITE_NOT_EXISTS); } } @Override public PageResult getWritePage(AiWritePageReqVO pageReqVO) { return writeMapper.selectPage(pageReqVO); } }