提交 | 用户 | 时间
|
449017
|
1 |
package com.iailab.module.model.mpk.controller.admin; |
D |
2 |
|
a8c6a6
|
3 |
import cn.hutool.cache.CacheUtil; |
D |
4 |
import cn.hutool.cache.impl.FIFOCache; |
f7e4a8
|
5 |
import cn.hutool.core.io.FileUtil; |
449017
|
6 |
import com.alibaba.fastjson.JSON; |
D |
7 |
import com.alibaba.fastjson.JSONArray; |
f7e4a8
|
8 |
import com.iail.IAILMDK; |
D |
9 |
import com.iail.bean.FieldSet; |
|
10 |
import com.iail.bean.Property; |
|
11 |
import com.iail.bean.SelectItem; |
|
12 |
import com.iail.model.IAILModel; |
3e18d4
|
13 |
import com.iail.utils.RSAUtils; |
449017
|
14 |
import com.iailab.framework.common.exception.enums.GlobalErrorCodeConstants; |
D |
15 |
import com.iailab.framework.common.pojo.CommonResult; |
558ffc
|
16 |
import com.iailab.framework.tenant.core.context.TenantContextHolder; |
1fea3e
|
17 |
import com.iailab.module.model.mpk.common.MdkConstant; |
912a1e
|
18 |
import com.iailab.module.model.mpk.common.utils.DllUtils; |
449017
|
19 |
import com.iailab.module.model.mpk.common.utils.Readtxt; |
D |
20 |
import com.iailab.module.model.mpk.dto.MdkDTO; |
3e18d4
|
21 |
import com.iailab.module.model.mpk.dto.MdkRunDTO; |
a8c6a6
|
22 |
import com.iailab.module.model.mpk.dto.MethodSettingDTO; |
449017
|
23 |
import io.swagger.v3.oas.annotations.Operation; |
912a1e
|
24 |
import lombok.extern.slf4j.Slf4j; |
f7e4a8
|
25 |
import org.apache.commons.io.IOUtils; |
1fea3e
|
26 |
import org.springframework.beans.factory.annotation.Value; |
449017
|
27 |
import org.springframework.util.CollectionUtils; |
D |
28 |
import org.springframework.web.bind.annotation.*; |
|
29 |
import org.springframework.web.multipart.MultipartFile; |
|
30 |
|
f7e4a8
|
31 |
import javax.servlet.http.HttpServletResponse; |
1fea3e
|
32 |
import java.io.File; |
f7e4a8
|
33 |
import java.io.IOException; |
3e18d4
|
34 |
import java.lang.reflect.Method; |
1fea3e
|
35 |
import java.net.URLClassLoader; |
f7e4a8
|
36 |
import java.net.URLEncoder; |
D |
37 |
import java.nio.file.Files; |
|
38 |
import java.util.*; |
|
39 |
import java.util.stream.Collectors; |
449017
|
40 |
|
D |
41 |
import static com.iailab.framework.common.pojo.CommonResult.error; |
|
42 |
import static com.iailab.framework.common.pojo.CommonResult.success; |
|
43 |
|
|
44 |
/** |
|
45 |
* @author PanZhibao |
|
46 |
* @Description |
|
47 |
* @createTime 2024年08月08日 |
|
48 |
*/ |
|
49 |
@RestController |
912a1e
|
50 |
@Slf4j |
449017
|
51 |
@RequestMapping("/model/mpk/api") |
D |
52 |
public class MdkController { |
1fea3e
|
53 |
@Value("${mpk.bak-file-path}") |
D |
54 |
private String mpkBakFilePath; |
|
55 |
|
f7e4a8
|
56 |
// 先进先出缓存 临时保存导入的数据 |
a8c6a6
|
57 |
private static FIFOCache<String, String> cache = CacheUtil.newFIFOCache(100); |
1fea3e
|
58 |
|
3e18d4
|
59 |
/** |
D |
60 |
* @description: 模型测试运行 |
|
61 |
* @author: dzd |
|
62 |
* @date: 2024/10/14 15:26 |
|
63 |
**/ |
|
64 |
@PostMapping("test") |
|
65 |
public CommonResult<String> test(@RequestBody MdkDTO dto) { |
558ffc
|
66 |
Long tenantId = TenantContextHolder.getTenantId(); |
D |
67 |
// 备份文件 租户隔离 |
|
68 |
String mpkTenantBakFilePath = mpkBakFilePath + File.separator + tenantId; |
|
69 |
|
912a1e
|
70 |
Class<?> clazz; |
D |
71 |
URLClassLoader classLoader; |
1fea3e
|
72 |
try { |
558ffc
|
73 |
File jarFile = new File(mpkTenantBakFilePath + File.separator + MdkConstant.JAR + File.separator + dto.getPyName() + ".jar"); |
1fea3e
|
74 |
if (!jarFile.exists()) { |
D |
75 |
throw new RuntimeException("jar包不存在,请先生成代码。jarPath:" + jarFile.getAbsolutePath()); |
|
76 |
} |
558ffc
|
77 |
File dllFile = new File(mpkTenantBakFilePath + File.separator + MdkConstant.DLL + File.separator + dto.getPyName() + ".dll"); |
1fea3e
|
78 |
if (!dllFile.exists()) { |
912a1e
|
79 |
throw new RuntimeException("dll文件不存在,请先生成代码。dllPath:" + dllFile.getAbsolutePath()); |
1fea3e
|
80 |
} |
D |
81 |
// 加载jar包 |
912a1e
|
82 |
classLoader = DllUtils.loadJar(jarFile.getAbsolutePath()); |
D |
83 |
// 实现类 |
|
84 |
clazz = classLoader.loadClass(dto.getClassName()); |
|
85 |
// 加载dll到实现类 |
|
86 |
DllUtils.loadDll(clazz,dllFile.getAbsolutePath()); |
1fea3e
|
87 |
} catch (Exception e) { |
912a1e
|
88 |
e.printStackTrace(); |
1fea3e
|
89 |
throw new RuntimeException("加载运行环境失败。"); |
D |
90 |
} |
|
91 |
|
449017
|
92 |
System.out.println("runTime=" + System.currentTimeMillis()); |
D |
93 |
try { |
a8c6a6
|
94 |
List<String> uuids = dto.getUuids(); |
449017
|
95 |
|
a8c6a6
|
96 |
int paramLength = dto.getHasModel() ? uuids.size() + 2 : uuids.size() + 1; |
449017
|
97 |
Object[] paramsValueArray = new Object[paramLength]; |
D |
98 |
Class<?>[] paramsArray = new Class[paramLength]; |
|
99 |
|
|
100 |
try { |
a8c6a6
|
101 |
for (int i = 0; i < uuids.size(); i++) { |
D |
102 |
String uuid = uuids.get(i); |
|
103 |
if (!cache.containsKey(uuid)) { |
|
104 |
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"请重新导入模型参数"); |
|
105 |
} |
|
106 |
JSONArray jsonArray = JSON.parseArray(cache.get(uuid)); |
449017
|
107 |
double[][] data = new double[jsonArray.size()][jsonArray.getJSONArray(0).size()]; |
D |
108 |
for (int j = 0; j < jsonArray.size(); j++) { |
|
109 |
for (int k = 0; k < jsonArray.getJSONArray(j).size(); k++) { |
|
110 |
data[j][k] = jsonArray.getJSONArray(j).getDoubleValue(k); |
|
111 |
} |
|
112 |
} |
|
113 |
paramsValueArray[i] = data; |
|
114 |
paramsArray[i] = double[][].class; |
|
115 |
} |
|
116 |
} catch (Exception e) { |
|
117 |
e.printStackTrace(); |
a8c6a6
|
118 |
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型参数错误,请检查!"); |
449017
|
119 |
} |
D |
120 |
|
a8c6a6
|
121 |
try { |
4f4b05
|
122 |
if (dto.getModelSettings().stream().noneMatch(e -> e.getSettingKey().equals(MdkConstant.PY_FILE_KEY))) { |
D |
123 |
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); |
|
124 |
} |
|
125 |
|
a8c6a6
|
126 |
if (dto.getHasModel()) { |
D |
127 |
paramsValueArray[uuids.size()] = dto.getModel(); |
|
128 |
paramsValueArray[uuids.size() + 1] = handleModelSettings(dto.getModelSettings()); |
|
129 |
paramsArray[uuids.size()] = HashMap.class; |
|
130 |
paramsArray[uuids.size() + 1] = HashMap.class; |
|
131 |
}else { |
|
132 |
paramsValueArray[uuids.size()] = handleModelSettings(dto.getModelSettings()); |
|
133 |
paramsArray[uuids.size()] = HashMap.class; |
|
134 |
} |
|
135 |
} catch (Exception e) { |
|
136 |
e.printStackTrace(); |
|
137 |
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型设置错误,请检查!"); |
449017
|
138 |
} |
D |
139 |
|
912a1e
|
140 |
HashMap result = (HashMap) clazz.getDeclaredMethod(dto.getMethodName(), paramsArray).invoke(clazz.newInstance(), paramsValueArray); |
449017
|
141 |
return success(JSON.toJSONString(result)); |
D |
142 |
} catch (Exception ex) { |
|
143 |
ex.printStackTrace(); |
|
144 |
return error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"运行异常"); |
|
145 |
} finally { |
912a1e
|
146 |
if (classLoader != null) { |
c12dae
|
147 |
DllUtils.unloadDll(classLoader); |
D |
148 |
DllUtils.unloadJar(classLoader); |
912a1e
|
149 |
} |
449017
|
150 |
System.gc(); |
D |
151 |
} |
a8c6a6
|
152 |
} |
D |
153 |
|
f7e4a8
|
154 |
private IAILModel createModelBean(MdkDTO dto) { |
D |
155 |
IAILModel modelBean = new IAILModel(); |
|
156 |
|
|
157 |
//ParamPathList |
|
158 |
List<String> paramPathList = new ArrayList<>(); |
|
159 |
List<String> paramNameList = new ArrayList<>(); |
|
160 |
|
|
161 |
for (Map.Entry<String, Object> entry : dto.getModel().entrySet()) { |
|
162 |
paramNameList.add(entry.getKey()); |
|
163 |
paramPathList.add(entry.getValue().toString()); |
|
164 |
} |
|
165 |
modelBean.setParamNameList(paramNameList); |
|
166 |
modelBean.setParamPathList(paramPathList); |
|
167 |
//ClassName MethodName |
|
168 |
modelBean.setClassName(dto.getClassName()); |
|
169 |
modelBean.setMethodName(dto.getMethodName()); |
|
170 |
//ParamsArray |
|
171 |
int paramLength = dto.getHasModel() ? dto.getDataLength() + 2 : dto.getDataLength() + 1; |
|
172 |
Class<?>[] paramsArray = new Class[paramLength]; |
|
173 |
|
|
174 |
for (int i = 0; i < dto.getDataLength(); i++) { |
|
175 |
paramsArray[i] = double[][].class; |
|
176 |
} |
|
177 |
|
|
178 |
if (dto.getHasModel()) { |
|
179 |
paramsArray[dto.getDataLength()] = HashMap.class; |
|
180 |
paramsArray[dto.getDataLength() + 1] = HashMap.class; |
|
181 |
}else { |
|
182 |
paramsArray[dto.getDataLength()] = HashMap.class; |
|
183 |
} |
|
184 |
modelBean.setParamsArray(paramsArray); |
|
185 |
//LoadFieldSetList |
|
186 |
List<FieldSet> loadFieldSetList = new ArrayList<>(); |
|
187 |
FieldSet fieldSet = new FieldSet(); |
|
188 |
fieldSet.setFieldName(""); |
|
189 |
List<Property> propertyList = new ArrayList<>(); |
|
190 |
for (MethodSettingDTO modelSetting : dto.getPredModelSettings()) { |
|
191 |
Property property = new Property(); |
|
192 |
property.setKey(modelSetting.getSettingKey()); |
|
193 |
property.setName(modelSetting.getName()); |
|
194 |
property.setType(modelSetting.getType()); |
|
195 |
property.setValueType(modelSetting.getValueType()); |
|
196 |
property.setMin(modelSetting.getMin() == null ? "" : modelSetting.getMin().toString()); |
|
197 |
property.setMax(modelSetting.getMax() == null ? "" : modelSetting.getMax().toString()); |
|
198 |
property.setSelectItemList(CollectionUtils.isEmpty(modelSetting.getSettingSelects()) ? null : modelSetting.getSettingSelects().stream().map(e -> new SelectItem(e.getSelectKey(),e.getName())).collect(Collectors.toList())); |
|
199 |
property.setValue(modelSetting.getValue()); |
|
200 |
property.setFlow(false); |
|
201 |
propertyList.add(property); |
|
202 |
} |
|
203 |
fieldSet.setPropertyList(propertyList); |
|
204 |
loadFieldSetList.add(fieldSet); |
|
205 |
modelBean.setLoadFieldSetList(loadFieldSetList); |
|
206 |
//SettingConfigMap |
|
207 |
Map<String, Object> settingConfigMap = new HashMap<String, Object>(); |
|
208 |
List<com.iail.bean.Value> settingKeyList = new ArrayList<com.iail.bean.Value>(); |
|
209 |
Map<String, Object> settingMap = new HashMap<String, Object>(); |
|
210 |
for (MethodSettingDTO modelSetting : dto.getModelSettings()) { |
|
211 |
settingKeyList.add(new com.iail.bean.Value(modelSetting.getSettingKey(),modelSetting.getSettingKey())); |
|
212 |
settingConfigMap.put("settingKeyList", settingKeyList); |
|
213 |
settingConfigMap.put("settingMap", handleModelSettings(dto.getModelSettings())); |
|
214 |
} |
|
215 |
modelBean.setSettingConfigMap(settingConfigMap); |
|
216 |
//DataMap |
|
217 |
modelBean.setDataMap(dto.getModelResult()); |
|
218 |
//ResultKey |
|
219 |
modelBean.setResultKey(dto.getResultKey()); |
|
220 |
//ResultKey |
|
221 |
modelBean.setVersion("1.0.0"); |
|
222 |
|
|
223 |
|
|
224 |
return modelBean; |
|
225 |
} |
|
226 |
|
|
227 |
@PostMapping("saveModel") |
|
228 |
public void saveModel(@RequestBody MdkDTO dto, HttpServletResponse response) { |
|
229 |
IAILModel modelBean = createModelBean(dto); |
|
230 |
|
|
231 |
try { |
|
232 |
//临时文件夹 |
|
233 |
File tempFile = null; |
|
234 |
try { |
|
235 |
tempFile = Files.createTempFile(dto.getPyName(),".miail").toFile(); |
|
236 |
log.info("生成临时文件," + tempFile.getAbsolutePath()); |
|
237 |
} catch (IOException e) { |
|
238 |
throw new RuntimeException("创建临时文件异常",e); |
|
239 |
} |
|
240 |
|
|
241 |
|
|
242 |
|
|
243 |
try { |
|
244 |
IAILMDK.saveModel(tempFile, modelBean); |
|
245 |
} catch (Exception e) { |
|
246 |
throw new RuntimeException("IAILMDK.saveModel异常",e); |
|
247 |
} |
|
248 |
|
|
249 |
byte[] data = FileUtil.readBytes(tempFile); |
|
250 |
response.reset(); |
|
251 |
response.setHeader("Content-Disposition", "attachment; filename=\"" + URLEncoder.encode(tempFile.getName(), "UTF-8") + "\""); |
|
252 |
response.addHeader("Content-Length", "" + data.length); |
|
253 |
response.setContentType("application/octet-stream; charset=UTF-8"); |
|
254 |
|
|
255 |
IOUtils.write(data, response.getOutputStream()); |
|
256 |
} catch (Exception e) { |
|
257 |
throw new RuntimeException("代码生成异常",e); |
|
258 |
} |
|
259 |
} |
|
260 |
|
a8c6a6
|
261 |
private HashMap<String, Object> handleModelSettings(List<MethodSettingDTO> modelSettings) { |
D |
262 |
HashMap<String, Object> resultMap = null; |
|
263 |
try { |
|
264 |
resultMap = new HashMap<>(modelSettings.size()); |
|
265 |
for (MethodSettingDTO modelSetting : modelSettings) { |
|
266 |
switch (modelSetting.getValueType()) { |
|
267 |
case "int": |
|
268 |
resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue())); |
|
269 |
break; |
|
270 |
case "string": |
|
271 |
resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue()); |
|
272 |
break; |
|
273 |
case "decimal": |
|
274 |
resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue())); |
|
275 |
break; |
|
276 |
case "decimalArray": |
|
277 |
JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue()); |
|
278 |
double[] doubles = new double[jsonArray.size()]; |
|
279 |
for (int i = 0; i < jsonArray.size(); i++) { |
|
280 |
doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i))); |
|
281 |
} |
|
282 |
resultMap.put(modelSetting.getSettingKey(), doubles); |
|
283 |
break; |
|
284 |
} |
|
285 |
} |
|
286 |
} catch (NumberFormatException e) { |
|
287 |
throw new RuntimeException("模型参数有误,请检查!!!"); |
|
288 |
} |
|
289 |
return resultMap; |
449017
|
290 |
} |
D |
291 |
|
3e18d4
|
292 |
/** |
D |
293 |
* @description: 模型运行 |
|
294 |
* @author: dzd |
|
295 |
* @date: 2024/10/14 15:26 |
|
296 |
**/ |
|
297 |
@PostMapping("run") |
|
298 |
public CommonResult<String> run(@RequestBody MdkRunDTO dto) { |
|
299 |
if (RSAUtils.checkLisenceBean().getCode() != 1) { |
|
300 |
log.error("Lisence 不可用!"); |
|
301 |
return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"Lisence 不可用!"); |
|
302 |
} else { |
|
303 |
try { |
|
304 |
URLClassLoader classLoader = DllUtils.getClassLoader(dto.getProjectId()); |
|
305 |
if (null == classLoader) { |
|
306 |
return CommonResult.error(GlobalErrorCodeConstants.ERROR_CONFIGURATION.getCode(),"请先发布项目!"); |
|
307 |
} |
|
308 |
|
|
309 |
Class<?> clazz = classLoader.loadClass(dto.getClassName()); |
|
310 |
Method method = clazz.getMethod(dto.getMethodName(), dto.getParamsClassArray()); |
|
311 |
HashMap invoke = (HashMap) method.invoke(clazz.newInstance(), dto.getParamsValueArray()); |
|
312 |
|
|
313 |
// todo 将结果存入数据库 |
|
314 |
|
|
315 |
} catch (Exception e) { |
|
316 |
log.error("模型运行失败",e); |
|
317 |
return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"模型运行失败!"); |
|
318 |
} |
|
319 |
} |
b131f0
|
320 |
return CommonResult.success(""); |
3e18d4
|
321 |
} |
D |
322 |
|
449017
|
323 |
@PostMapping("/import") |
D |
324 |
@Operation(summary = "导入参数") |
a8c6a6
|
325 |
public CommonResult<List<HashMap<String,Object>>> importExcel(@RequestParam("file") MultipartFile file) throws Exception { |
449017
|
326 |
List<double[][]> datas = Readtxt.readMethodExcel(file); |
a8c6a6
|
327 |
List<HashMap<String,Object>> result = new ArrayList<>(); |
449017
|
328 |
if (!CollectionUtils.isEmpty(datas)) { |
D |
329 |
for (double[][] data : datas) { |
|
330 |
if (data.length > 0) { |
a8c6a6
|
331 |
HashMap<String,Object> map = new HashMap<>(); |
D |
332 |
String uuid = UUID.randomUUID().toString(); |
|
333 |
map.put("uuid",uuid); |
|
334 |
map.put("data",JSON.toJSONString(data)); |
|
335 |
cache.put(uuid,JSON.toJSONString(data)); |
|
336 |
result.add(map); |
449017
|
337 |
} |
D |
338 |
} |
|
339 |
} |
|
340 |
return success(result); |
1fea3e
|
341 |
} |
449017
|
342 |
} |