dengzedong
2025-02-21 f7e4a8c81cb019d9aef5ff55ddedf8083943ca8b
提交 | 用户 | 时间
449017 1 package com.iailab.module.model.mpk.controller.admin;
D 2
a8c6a6 3 import cn.hutool.cache.CacheUtil;
D 4 import cn.hutool.cache.impl.FIFOCache;
f7e4a8 5 import cn.hutool.core.io.FileUtil;
449017 6 import com.alibaba.fastjson.JSON;
D 7 import com.alibaba.fastjson.JSONArray;
f7e4a8 8 import com.iail.IAILMDK;
D 9 import com.iail.bean.FieldSet;
10 import com.iail.bean.Property;
11 import com.iail.bean.SelectItem;
12 import com.iail.model.IAILModel;
3e18d4 13 import com.iail.utils.RSAUtils;
449017 14 import com.iailab.framework.common.exception.enums.GlobalErrorCodeConstants;
D 15 import com.iailab.framework.common.pojo.CommonResult;
558ffc 16 import com.iailab.framework.tenant.core.context.TenantContextHolder;
1fea3e 17 import com.iailab.module.model.mpk.common.MdkConstant;
912a1e 18 import com.iailab.module.model.mpk.common.utils.DllUtils;
449017 19 import com.iailab.module.model.mpk.common.utils.Readtxt;
D 20 import com.iailab.module.model.mpk.dto.MdkDTO;
3e18d4 21 import com.iailab.module.model.mpk.dto.MdkRunDTO;
a8c6a6 22 import com.iailab.module.model.mpk.dto.MethodSettingDTO;
449017 23 import io.swagger.v3.oas.annotations.Operation;
912a1e 24 import lombok.extern.slf4j.Slf4j;
f7e4a8 25 import org.apache.commons.io.IOUtils;
1fea3e 26 import org.springframework.beans.factory.annotation.Value;
449017 27 import org.springframework.util.CollectionUtils;
D 28 import org.springframework.web.bind.annotation.*;
29 import org.springframework.web.multipart.MultipartFile;
30
f7e4a8 31 import javax.servlet.http.HttpServletResponse;
1fea3e 32 import java.io.File;
f7e4a8 33 import java.io.IOException;
3e18d4 34 import java.lang.reflect.Method;
1fea3e 35 import java.net.URLClassLoader;
f7e4a8 36 import java.net.URLEncoder;
D 37 import java.nio.file.Files;
38 import java.util.*;
39 import java.util.stream.Collectors;
449017 40
D 41 import static com.iailab.framework.common.pojo.CommonResult.error;
42 import static com.iailab.framework.common.pojo.CommonResult.success;
43
44 /**
45  * @author PanZhibao
46  * @Description
47  * @createTime 2024年08月08日
48  */
49 @RestController
912a1e 50 @Slf4j
449017 51 @RequestMapping("/model/mpk/api")
D 52 public class MdkController {
1fea3e 53     @Value("${mpk.bak-file-path}")
D 54     private String mpkBakFilePath;
55
f7e4a8 56     // 先进先出缓存 临时保存导入的数据
a8c6a6 57     private static FIFOCache<String, String> cache = CacheUtil.newFIFOCache(100);
1fea3e 58
3e18d4 59     /**
D 60      * @description: 模型测试运行
61      * @author: dzd
62      * @date: 2024/10/14 15:26
63      **/
64     @PostMapping("test")
65     public CommonResult<String> test(@RequestBody MdkDTO dto) {
558ffc 66         Long tenantId = TenantContextHolder.getTenantId();
D 67         // 备份文件 租户隔离
68         String mpkTenantBakFilePath = mpkBakFilePath + File.separator + tenantId;
69
912a1e 70         Class<?> clazz;
D 71         URLClassLoader classLoader;
1fea3e 72         try {
558ffc 73             File jarFile = new File(mpkTenantBakFilePath + File.separator + MdkConstant.JAR + File.separator + dto.getPyName() + ".jar");
1fea3e 74             if (!jarFile.exists()) {
D 75                 throw new RuntimeException("jar包不存在,请先生成代码。jarPath:" + jarFile.getAbsolutePath());
76             }
558ffc 77             File dllFile = new File(mpkTenantBakFilePath + File.separator + MdkConstant.DLL + File.separator + dto.getPyName() + ".dll");
1fea3e 78             if (!dllFile.exists()) {
912a1e 79                 throw new RuntimeException("dll文件不存在,请先生成代码。dllPath:" + dllFile.getAbsolutePath());
1fea3e 80             }
D 81             // 加载jar包
912a1e 82             classLoader = DllUtils.loadJar(jarFile.getAbsolutePath());
D 83             // 实现类
84             clazz = classLoader.loadClass(dto.getClassName());
85             // 加载dll到实现类
86             DllUtils.loadDll(clazz,dllFile.getAbsolutePath());
1fea3e 87         } catch (Exception e) {
912a1e 88             e.printStackTrace();
1fea3e 89             throw new RuntimeException("加载运行环境失败。");
D 90         }
91
449017 92         System.out.println("runTime=" + System.currentTimeMillis());
D 93         try {
a8c6a6 94             List<String> uuids = dto.getUuids();
449017 95
a8c6a6 96             int paramLength = dto.getHasModel() ? uuids.size() + 2 : uuids.size() + 1;
449017 97             Object[] paramsValueArray = new Object[paramLength];
D 98             Class<?>[] paramsArray = new Class[paramLength];
99
100             try {
a8c6a6 101                 for (int i = 0; i < uuids.size(); i++) {
D 102                     String uuid = uuids.get(i);
103                     if (!cache.containsKey(uuid)) {
104                         return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"请重新导入模型参数");
105                     }
106                     JSONArray jsonArray = JSON.parseArray(cache.get(uuid));
449017 107                     double[][] data = new double[jsonArray.size()][jsonArray.getJSONArray(0).size()];
D 108                     for (int j = 0; j < jsonArray.size(); j++) {
109                         for (int k = 0; k < jsonArray.getJSONArray(j).size(); k++) {
110                             data[j][k] = jsonArray.getJSONArray(j).getDoubleValue(k);
111                         }
112                     }
113                     paramsValueArray[i] = data;
114                     paramsArray[i] = double[][].class;
115                 }
116             } catch (Exception e) {
117                 e.printStackTrace();
a8c6a6 118                 return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型参数错误,请检查!");
449017 119             }
D 120
a8c6a6 121             try {
4f4b05 122                 if (dto.getModelSettings().stream().noneMatch(e -> e.getSettingKey().equals(MdkConstant.PY_FILE_KEY))) {
D 123                     return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
124                 }
125
a8c6a6 126                 if (dto.getHasModel()) {
D 127                     paramsValueArray[uuids.size()] = dto.getModel();
128                     paramsValueArray[uuids.size() + 1] = handleModelSettings(dto.getModelSettings());
129                     paramsArray[uuids.size()] = HashMap.class;
130                     paramsArray[uuids.size() + 1] = HashMap.class;
131                 }else {
132                     paramsValueArray[uuids.size()] = handleModelSettings(dto.getModelSettings());
133                     paramsArray[uuids.size()] = HashMap.class;
134                 }
135             } catch (Exception e) {
136                 e.printStackTrace();
137                 return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型设置错误,请检查!");
449017 138             }
D 139
912a1e 140             HashMap result = (HashMap) clazz.getDeclaredMethod(dto.getMethodName(), paramsArray).invoke(clazz.newInstance(), paramsValueArray);
449017 141             return success(JSON.toJSONString(result));
D 142         } catch (Exception ex) {
143             ex.printStackTrace();
144             return error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"运行异常");
145         } finally {
912a1e 146             if (classLoader != null) {
c12dae 147                 DllUtils.unloadDll(classLoader);
D 148                 DllUtils.unloadJar(classLoader);
912a1e 149             }
449017 150             System.gc();
D 151         }
a8c6a6 152     }
D 153
f7e4a8 154     private IAILModel createModelBean(MdkDTO dto) {
D 155         IAILModel modelBean = new IAILModel();
156
157         //ParamPathList
158         List<String> paramPathList = new ArrayList<>();
159         List<String> paramNameList = new ArrayList<>();
160
161         for (Map.Entry<String, Object> entry : dto.getModel().entrySet()) {
162             paramNameList.add(entry.getKey());
163             paramPathList.add(entry.getValue().toString());
164         }
165         modelBean.setParamNameList(paramNameList);
166         modelBean.setParamPathList(paramPathList);
167         //ClassName MethodName
168         modelBean.setClassName(dto.getClassName());
169         modelBean.setMethodName(dto.getMethodName());
170         //ParamsArray
171         int paramLength = dto.getHasModel() ? dto.getDataLength() + 2 : dto.getDataLength() + 1;
172         Class<?>[] paramsArray = new Class[paramLength];
173
174         for (int i = 0; i < dto.getDataLength(); i++) {
175             paramsArray[i] = double[][].class;
176         }
177
178         if (dto.getHasModel()) {
179             paramsArray[dto.getDataLength()] = HashMap.class;
180             paramsArray[dto.getDataLength() + 1] = HashMap.class;
181         }else {
182             paramsArray[dto.getDataLength()] = HashMap.class;
183         }
184         modelBean.setParamsArray(paramsArray);
185         //LoadFieldSetList
186         List<FieldSet> loadFieldSetList = new ArrayList<>();
187         FieldSet fieldSet = new FieldSet();
188         fieldSet.setFieldName("");
189         List<Property> propertyList = new ArrayList<>();
190         for (MethodSettingDTO modelSetting : dto.getPredModelSettings()) {
191             Property property = new Property();
192             property.setKey(modelSetting.getSettingKey());
193             property.setName(modelSetting.getName());
194             property.setType(modelSetting.getType());
195             property.setValueType(modelSetting.getValueType());
196             property.setMin(modelSetting.getMin() == null ? "" : modelSetting.getMin().toString());
197             property.setMax(modelSetting.getMax() == null ? "" : modelSetting.getMax().toString());
198             property.setSelectItemList(CollectionUtils.isEmpty(modelSetting.getSettingSelects()) ? null : modelSetting.getSettingSelects().stream().map(e -> new SelectItem(e.getSelectKey(),e.getName())).collect(Collectors.toList()));
199             property.setValue(modelSetting.getValue());
200             property.setFlow(false);
201             propertyList.add(property);
202         }
203         fieldSet.setPropertyList(propertyList);
204         loadFieldSetList.add(fieldSet);
205         modelBean.setLoadFieldSetList(loadFieldSetList);
206         //SettingConfigMap
207         Map<String, Object> settingConfigMap = new HashMap<String, Object>();
208         List<com.iail.bean.Value> settingKeyList = new ArrayList<com.iail.bean.Value>();
209         Map<String, Object> settingMap = new HashMap<String, Object>();
210         for (MethodSettingDTO modelSetting : dto.getModelSettings()) {
211             settingKeyList.add(new com.iail.bean.Value(modelSetting.getSettingKey(),modelSetting.getSettingKey()));
212             settingConfigMap.put("settingKeyList", settingKeyList);
213             settingConfigMap.put("settingMap", handleModelSettings(dto.getModelSettings()));
214         }
215         modelBean.setSettingConfigMap(settingConfigMap);
216         //DataMap
217         modelBean.setDataMap(dto.getModelResult());
218         //ResultKey
219         modelBean.setResultKey(dto.getResultKey());
220         //ResultKey
221         modelBean.setVersion("1.0.0");
222
223
224         return modelBean;
225     }
226
227     @PostMapping("saveModel")
228     public void saveModel(@RequestBody MdkDTO dto, HttpServletResponse response) {
229         IAILModel modelBean = createModelBean(dto);
230
231         try {
232             //临时文件夹
233             File tempFile = null;
234             try {
235                 tempFile = Files.createTempFile(dto.getPyName(),".miail").toFile();
236                 log.info("生成临时文件," + tempFile.getAbsolutePath());
237             } catch (IOException e) {
238                 throw new RuntimeException("创建临时文件异常",e);
239             }
240
241
242
243             try {
244                 IAILMDK.saveModel(tempFile, modelBean);
245             } catch (Exception e) {
246                 throw new RuntimeException("IAILMDK.saveModel异常",e);
247             }
248
249             byte[] data = FileUtil.readBytes(tempFile);
250             response.reset();
251             response.setHeader("Content-Disposition", "attachment; filename=\"" + URLEncoder.encode(tempFile.getName(), "UTF-8") + "\"");
252             response.addHeader("Content-Length", "" + data.length);
253             response.setContentType("application/octet-stream; charset=UTF-8");
254
255             IOUtils.write(data, response.getOutputStream());
256         } catch (Exception e) {
257             throw new RuntimeException("代码生成异常",e);
258         }
259     }
260
a8c6a6 261     private HashMap<String, Object> handleModelSettings(List<MethodSettingDTO> modelSettings) {
D 262         HashMap<String, Object> resultMap = null;
263         try {
264             resultMap = new HashMap<>(modelSettings.size());
265             for (MethodSettingDTO modelSetting : modelSettings) {
266                 switch (modelSetting.getValueType()) {
267                     case "int":
268                         resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue()));
269                         break;
270                     case "string":
271                         resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue());
272                         break;
273                     case "decimal":
274                         resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue()));
275                         break;
276                     case "decimalArray":
277                         JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue());
278                         double[] doubles = new double[jsonArray.size()];
279                         for (int i = 0; i < jsonArray.size(); i++) {
280                             doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i)));
281                         }
282                         resultMap.put(modelSetting.getSettingKey(), doubles);
283                         break;
284                 }
285             }
286         } catch (NumberFormatException e) {
287             throw new RuntimeException("模型参数有误,请检查!!!");
288         }
289         return resultMap;
449017 290     }
D 291
3e18d4 292     /**
D 293      * @description: 模型运行
294      * @author: dzd
295      * @date: 2024/10/14 15:26
296      **/
297     @PostMapping("run")
298     public CommonResult<String> run(@RequestBody MdkRunDTO dto) {
299         if (RSAUtils.checkLisenceBean().getCode() != 1) {
300             log.error("Lisence 不可用!");
301             return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"Lisence 不可用!");
302         } else {
303             try {
304                 URLClassLoader classLoader = DllUtils.getClassLoader(dto.getProjectId());
305                 if (null == classLoader) {
306                     return CommonResult.error(GlobalErrorCodeConstants.ERROR_CONFIGURATION.getCode(),"请先发布项目!");
307                 }
308
309                 Class<?> clazz = classLoader.loadClass(dto.getClassName());
310                 Method method = clazz.getMethod(dto.getMethodName(), dto.getParamsClassArray());
311                 HashMap invoke = (HashMap) method.invoke(clazz.newInstance(), dto.getParamsValueArray());
312
313                 // todo 将结果存入数据库
314
315             } catch (Exception e) {
316                 log.error("模型运行失败",e);
317                 return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"模型运行失败!");
318             }
319         }
b131f0 320         return CommonResult.success("");
3e18d4 321     }
D 322
449017 323     @PostMapping("/import")
D 324     @Operation(summary = "导入参数")
a8c6a6 325     public CommonResult<List<HashMap<String,Object>>> importExcel(@RequestParam("file") MultipartFile file) throws Exception {
449017 326         List<double[][]> datas = Readtxt.readMethodExcel(file);
a8c6a6 327         List<HashMap<String,Object>> result = new ArrayList<>();
449017 328         if (!CollectionUtils.isEmpty(datas)) {
D 329             for (double[][] data : datas) {
330                 if (data.length > 0) {
a8c6a6 331                     HashMap<String,Object> map = new HashMap<>();
D 332                     String uuid = UUID.randomUUID().toString();
333                     map.put("uuid",uuid);
334                     map.put("data",JSON.toJSONString(data));
335                     cache.put(uuid,JSON.toJSONString(data));
336                     result.add(map);
449017 337                 }
D 338             }
339         }
340         return success(result);
1fea3e 341     }
449017 342 }