工业互联网平台2.0版本后端代码
houzhongjian
2025-05-29 41499fd3c28216c1526a72b10fa98eb8ffee78cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/*
 * Copyright 2023-2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
package com.iailab.framework.ai.core.model.siliconflow;
 
import io.micrometer.observation.ObservationRegistry;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.*;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
 
import java.util.List;
 
/**
 * 硅基流动 {@link ImageModel} 实现类
 *
 * 参考 {@link OpenAiImageModel} 实现
 *
 * @author zzt
 */
public class SiliconFlowImageModel implements ImageModel {
 
    private static final Logger logger = LoggerFactory.getLogger(SiliconFlowImageModel.class);
 
    private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();
 
    private final SiliconFlowImageOptions defaultOptions;
 
    private final RetryTemplate retryTemplate;
 
    private final SiliconFlowImageApi siliconFlowImageApi;
 
    private final ObservationRegistry observationRegistry;
 
    @Setter
    private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
 
    public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi) {
        this(siliconFlowImageApi, SiliconFlowImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }
 
    public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate) {
        this(siliconFlowImageApi, options, retryTemplate, ObservationRegistry.NOOP);
    }
 
    public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate,
                                 ObservationRegistry observationRegistry) {
        Assert.notNull(siliconFlowImageApi, "OpenAiImageApi must not be null");
        Assert.notNull(options, "options must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.siliconFlowImageApi = siliconFlowImageApi;
        this.defaultOptions = options;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }
 
    @Override
    public ImageResponse call(ImagePrompt imagePrompt) {
        SiliconFlowImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
        SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
 
        var observationContext = ImageModelObservationContext.builder()
            .imagePrompt(imagePrompt)
            .provider(SiliconFlowApiConstants.PROVIDER_NAME)
            .requestOptions(imagePrompt.getOptions())
            .build();
 
        return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity = this.retryTemplate
                    .execute(ctx -> this.siliconFlowImageApi.createImage(imageRequest));
 
                ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest);
 
                observationContext.setResponse(imageResponse);
 
                return imageResponse;
            });
    }
 
    private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt,
                                                                      SiliconFlowImageOptions requestImageOptions) {
        String instructions = imagePrompt.getInstructions().get(0).getText();
 
        SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions,
                SiliconFlowApiConstants.DEFAULT_IMAGE_MODEL);
 
        return ModelOptionsUtils.merge(requestImageOptions, imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
    }
 
    private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
                                          SiliconFlowImageApi.SiliconflowImageRequest siliconflowImageRequest) {
        OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody();
        if (imageApiResponse == null) {
            logger.warn("No image response returned for request: {}", siliconflowImageRequest);
            return new ImageResponse(List.of());
        }
 
        List<ImageGeneration> imageGenerationList = imageApiResponse.data()
            .stream()
            .map(entry -> new ImageGeneration(new Image(entry.url(), entry.b64Json()),
                    new OpenAiImageGenerationMetadata(entry.revisedPrompt())))
            .toList();
 
        ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
        return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
    }
 
    private SiliconFlowImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, SiliconFlowImageOptions defaultOptions) {
        var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
                SiliconFlowImageOptions.class);
 
        if (runtimeOptionsForProvider == null) {
            return defaultOptions;
        }
 
        return SiliconFlowImageOptions.builder()
                // Handle portable image options
                .model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
                .batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
                .width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
                .height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
                // Handle SiliconFlow specific image options
                .negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
                .numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
                .guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))
                .seed(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getSeed(), defaultOptions.getSeed()))
                .build();
    }
}