提交 | 用户 | 时间
|
3e61b6
|
1 |
package com.iailab.module.model.matlab.common.utils; |
D |
2 |
|
|
3 |
import com.alibaba.fastjson.JSON; |
|
4 |
import com.alibaba.fastjson.JSONArray; |
|
5 |
import com.iail.model.IAILModel; |
|
6 |
import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException; |
|
7 |
import com.iailab.module.model.matlab.dto.MatlabJarFileClassInfoDTO; |
|
8 |
import com.iailab.module.model.matlab.dto.MatlabJarFileMethodInfoDTO; |
|
9 |
import com.iailab.module.model.matlab.dto.MlModelMethodSettingDTO; |
|
10 |
import com.iailab.module.model.mpk.common.MdkConstant; |
|
11 |
import com.mathworks.toolbox.javabuilder.*; |
|
12 |
import com.mathworks.toolbox.javabuilder.internal.MWFunctionSignature; |
|
13 |
import lombok.extern.slf4j.Slf4j; |
|
14 |
|
|
15 |
import java.io.File; |
|
16 |
import java.io.IOException; |
|
17 |
import java.lang.reflect.Field; |
|
18 |
import java.lang.reflect.Method; |
|
19 |
import java.lang.reflect.Modifier; |
|
20 |
import java.net.MalformedURLException; |
|
21 |
import java.net.URL; |
|
22 |
import java.net.URLClassLoader; |
|
23 |
import java.util.*; |
|
24 |
import java.util.jar.JarEntry; |
|
25 |
import java.util.jar.JarFile; |
|
26 |
import java.util.stream.Collectors; |
|
27 |
|
|
28 |
@Slf4j |
|
29 |
public class MatlabUtils { |
|
30 |
|
|
31 |
private static HashMap<String, URLClassLoader> classLoaderCache = new HashMap<>(); |
|
32 |
private static HashMap<String, Object> classCache = new HashMap<>(); |
|
33 |
private static HashMap<String, Method> classMethodCache = new HashMap<>(); |
|
34 |
|
|
35 |
/** |
|
36 |
* @description: 解析matlab jar文件 |
|
37 |
**/ |
|
38 |
public static List<MatlabJarFileClassInfoDTO> parseJarInfo(String jarFilePath, String jatName) throws IllegalityJarException { |
|
39 |
List<MatlabJarFileClassInfoDTO> classInfos = new ArrayList<>(); |
|
40 |
//加载jar用于解析内部class method |
|
41 |
URLClassLoader urlClassLoader = loadJar(null,jarFilePath); |
|
42 |
try (JarFile jarFile = new JarFile(jarFilePath)) { |
|
43 |
// 获取 JAR 文件中所有的条目 |
|
44 |
Enumeration<JarEntry> entries = jarFile.entries(); |
|
45 |
while (entries.hasMoreElements()) { |
|
46 |
JarEntry entry = entries.nextElement(); |
|
47 |
String entryName = entry.getName(); |
|
48 |
// 检查条目是否为类文件 |
|
49 |
if (entryName.endsWith(".class") && !entryName.endsWith("Remote.class") && !entryName.endsWith("MCRFactory.class")) { |
|
50 |
MatlabJarFileClassInfoDTO classInfo = new MatlabJarFileClassInfoDTO(); |
|
51 |
// 将类文件的路径转换为类的全限定名 |
|
52 |
String className = entryName.replace('/', '.').substring(0, entryName.lastIndexOf(".class")); |
|
53 |
// 校验包名是否和文件名一致,不一致则说明修改过jar包名,判定为非法包 |
|
54 |
if (!className.startsWith(jatName)) { |
|
55 |
throw new IllegalityJarException(); |
|
56 |
} |
|
57 |
|
|
58 |
classInfo.setClassName(className); |
|
59 |
try { |
|
60 |
// 加载类 |
|
61 |
Class<?> clazz = urlClassLoader.loadClass(className); |
|
62 |
// 获取该类的所有属性 |
|
63 |
Field[] fields = clazz.getDeclaredFields(); |
|
64 |
List<MatlabJarFileMethodInfoDTO> methodInfos = new ArrayList<>(); |
|
65 |
for (Field field : fields) { |
|
66 |
int modifiers = field.getModifiers(); |
|
67 |
if (Modifier.isPrivate(modifiers) && Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) { |
|
68 |
try { |
|
69 |
// 绕过访问控制检查 |
|
70 |
field.setAccessible(true); |
|
71 |
Object value = field.get(null); // 静态字段不需要实例,传 null |
|
72 |
if (value instanceof MWFunctionSignature) { |
|
73 |
MatlabJarFileMethodInfoDTO methodInfo = new MatlabJarFileMethodInfoDTO(); |
|
74 |
methodInfo.setMethodName(value.getClass().getField("name").get(value).toString()); |
|
75 |
// 减1是为了排除最后面的setting |
|
76 |
methodInfo.setDataLength((Integer) value.getClass().getField("numInputs").get(value) - 1); |
|
77 |
methodInfo.setOutLength((Integer) value.getClass().getField("numOutputs").get(value)); |
|
78 |
methodInfos.add(methodInfo); |
|
79 |
} |
|
80 |
} catch (Exception e) { |
|
81 |
log.error("get matlab method info exception,className:" + className, e); |
|
82 |
} |
|
83 |
} |
|
84 |
} |
|
85 |
classInfo.setMethodInfos(methodInfos); |
|
86 |
} catch (Exception e) { |
|
87 |
log.error("get matlab class info exception,className:" + className, e); |
|
88 |
} |
|
89 |
classInfos.add(classInfo); |
|
90 |
} |
|
91 |
} |
|
92 |
} catch (IOException e) { |
|
93 |
log.error("get matlab jar info exception,jarFilePath:" + jarFilePath, e); |
|
94 |
} finally { |
|
95 |
unloadJar(urlClassLoader); |
|
96 |
} |
|
97 |
return classInfos; |
|
98 |
} |
|
99 |
|
|
100 |
public static synchronized URLClassLoader loadJar(String projectId, String... jarPaths) { |
|
101 |
try { |
|
102 |
URL[] urls = new URL[jarPaths.length]; |
|
103 |
for (int i = 0; i < jarPaths.length; i++) { |
|
104 |
String jarPath = jarPaths[i]; |
|
105 |
File jarFile = new File(jarPath); |
|
106 |
if (!jarFile.exists()) { |
|
107 |
throw new RuntimeException("jar沒有找到!" + jarPath); |
|
108 |
} |
|
109 |
urls[i] = new File(jarPath).toURI().toURL(); |
|
110 |
} |
|
111 |
|
|
112 |
// 不能设置classloader的patent,让它使用双亲委派去找系统classloader中javabuilder.jar中的依赖。因为javabuilder.jar只能加载一次,加载多次会报资源占用 |
|
113 |
URLClassLoader urlClassLoader = new URLClassLoader(urls); |
|
114 |
if (projectId != null) { |
|
115 |
addClassLoaderCache(projectId, urlClassLoader); |
|
116 |
} |
|
117 |
log.info("成功加载jar包:" + String.join(";",jarPaths)); |
|
118 |
return urlClassLoader; |
|
119 |
} catch (Exception e) { |
|
120 |
throw new RuntimeException("加载jar异常", e); |
|
121 |
} |
|
122 |
} |
|
123 |
|
|
124 |
public static synchronized void unloadJar(URLClassLoader urlClassLoader) { |
|
125 |
try { |
|
126 |
urlClassLoader.close(); |
|
127 |
log.info("成功卸载jar包。"); |
|
128 |
} catch (Exception e) { |
|
129 |
throw new RuntimeException("卸载jar异常", e); |
|
130 |
} |
|
131 |
} |
|
132 |
|
|
133 |
public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) { |
|
134 |
classLoaderCache.put(projectId, urlClassLoader); |
|
135 |
} |
|
136 |
|
|
137 |
public static synchronized URLClassLoader getClassLoader(String projectId) { |
|
138 |
return classLoaderCache.get(projectId); |
|
139 |
} |
|
140 |
|
|
141 |
public static synchronized void removeClassLoaderCache(String projectId) { |
|
142 |
if (classLoaderCache.containsKey(projectId)) { |
|
143 |
URLClassLoader urlClassLoader = classLoaderCache.get(projectId); |
|
144 |
unloadJar(urlClassLoader); |
|
145 |
classLoaderCache.remove(projectId); |
|
146 |
removeClassCache(projectId); |
|
147 |
removeClassMethodCache(projectId); |
|
148 |
} |
|
149 |
} |
|
150 |
|
|
151 |
public static synchronized void removeClassCache(String projectId) { |
|
152 |
Iterator<String> iterator = classCache.keySet().iterator(); |
|
153 |
while (iterator.hasNext()) { |
|
154 |
String key = iterator.next(); |
|
155 |
if (key.startsWith(projectId)) { |
|
156 |
iterator.remove(); |
|
157 |
} |
|
158 |
} |
|
159 |
} |
|
160 |
|
|
161 |
public static synchronized void removeClassMethodCache(String projectId) { |
|
162 |
Iterator<String> iterator = classMethodCache.keySet().iterator(); |
|
163 |
while (iterator.hasNext()) { |
|
164 |
String key = iterator.next(); |
|
165 |
if (key.startsWith(projectId)) { |
|
166 |
iterator.remove(); |
|
167 |
} |
|
168 |
} |
|
169 |
} |
|
170 |
|
|
171 |
public static void removeOldFile(String bakPath, String projectId) { |
|
172 |
File dir = new File(bakPath); |
|
173 |
if (dir.exists() && dir.isDirectory()) { |
|
174 |
File[] files = dir.listFiles(); |
|
175 |
if (null != files && files.length > 0) { |
|
176 |
for (File file : files) { |
|
177 |
if (file.getName().startsWith(projectId)) { |
|
178 |
file.delete(); |
|
179 |
} |
|
180 |
} |
|
181 |
} |
|
182 |
} |
|
183 |
} |
|
184 |
|
|
185 |
/** |
|
186 |
* @description: 项目启动加载已发布的dll和jar |
|
187 |
* @author: dzd |
|
188 |
* @date: 2024/10/10 11:58 |
|
189 |
**/ |
|
190 |
public static void loadProjectPublish(String bakPath) { |
|
191 |
File dir = new File(bakPath); |
|
192 |
if (dir.exists() && dir.isDirectory()) { |
|
193 |
File[] files = dir.listFiles(); |
|
194 |
if (null != files && files.length > 0) { |
|
195 |
for (File file : files) { |
|
196 |
String fileName = file.getName(); |
|
197 |
if (fileName.endsWith(".jar")) { |
|
198 |
String[] split = fileName.split(MdkConstant.SPLIT); |
|
199 |
String projectId = split[0]; |
|
200 |
|
|
201 |
URLClassLoader urlClassLoader = null; |
|
202 |
try { |
|
203 |
// 加载新的jar |
|
204 |
urlClassLoader = loadJar(projectId,file.getAbsolutePath()); |
|
205 |
} catch (Exception e) { |
|
206 |
throw new RuntimeException("加载jar异常", e); |
|
207 |
} |
|
208 |
// 成功后加入缓存 |
|
209 |
addClassLoaderCache(projectId, urlClassLoader); |
|
210 |
} |
|
211 |
} |
|
212 |
} |
|
213 |
} |
|
214 |
|
|
215 |
} |
|
216 |
|
|
217 |
public static HashMap<String, Object> run(IAILModel model, Object[] paramsValueArray, String projectId) throws Exception { |
|
218 |
if (model == null) { |
|
219 |
throw new RuntimeException("模型文件不能为空!"); |
|
220 |
} else { |
|
221 |
// 上传时校验文件名和class的开头是相同的,保证className唯一,可以用 projectId_className 作为唯一key |
|
222 |
String classCacheKey = projectId + "_" + model.getClassName(); |
|
223 |
String methodParams = Arrays.stream(model.getParamsArray()).map(e -> e.getName()).collect(Collectors.joining(",")); |
|
224 |
String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")"; |
|
225 |
// 因为一个类下可能有多多个方法,所以这里classCacheKey 和 classMethodCacheKey 要分开判断 |
|
226 |
if (classCache.containsKey(classCacheKey)) { |
|
227 |
if (classMethodCache.containsKey(classMethodCacheKey)) { |
|
228 |
return (HashMap) classMethodCache.get(classMethodCacheKey).invoke(classCache.get(classCacheKey), paramsValueArray); |
|
229 |
} else { |
|
230 |
// 运行过这个类的其他方法,类从缓存中取,新建方法 |
|
231 |
Object o = classCache.get(classCacheKey); |
|
232 |
Class<?>[] paramsArray = new Class[2]; |
|
233 |
paramsArray[0] = int.class; |
|
234 |
paramsArray[1] = Object[].class; |
|
235 |
Method method = o.getClass().getMethod(model.getMethodName(), paramsArray); |
|
236 |
classMethodCache.put(classMethodCacheKey, method); |
|
237 |
Object[] objects = (Object[]) method.invoke(o, paramsValueArray); |
|
238 |
HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]); |
|
239 |
return map; |
|
240 |
} |
|
241 |
|
|
242 |
} else { |
|
243 |
URLClassLoader classLoader = getClassLoader(projectId); |
|
244 |
if (null == classLoader) { |
|
245 |
throw new RuntimeException("matlab未发布,classLoader为null"); |
|
246 |
} |
|
247 |
Class<?> clazz = classLoader.loadClass(model.getClassName()); |
|
248 |
Object o = clazz.newInstance(); |
|
249 |
Class<?>[] paramsArray = new Class[2]; |
|
250 |
paramsArray[0] = int.class; |
|
251 |
paramsArray[1] = Object[].class; |
|
252 |
Method method = clazz.getMethod(model.getMethodName(), paramsArray); |
|
253 |
classCache.put(classCacheKey, o); |
|
254 |
classMethodCache.put(classMethodCacheKey, method); |
|
255 |
Object[] objects = (Object[]) method.invoke(o, paramsValueArray); |
|
256 |
HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]); |
|
257 |
return map; |
|
258 |
} |
|
259 |
} |
|
260 |
} |
|
261 |
|
|
262 |
public static HashMap<String, Object> convertStructToMap(MWStructArray struct) { |
|
263 |
HashMap<String, Object> map; |
|
264 |
try { |
|
265 |
map = new HashMap<>(); |
|
266 |
String[] fieldNames = struct.fieldNames(); |
|
267 |
int numElements = struct.numberOfElements(); |
|
268 |
for (int i = 1; i <= numElements; i++) { |
|
269 |
for (String fieldName : fieldNames) { |
|
270 |
MWArray value = struct.getField(fieldName, i); |
|
271 |
Object javaValue = convertMWArrayToJavaObject(value); |
|
272 |
map.put(fieldName, javaValue); |
|
273 |
} |
|
274 |
} |
|
275 |
} finally { |
|
276 |
struct.dispose(); |
|
277 |
} |
|
278 |
return map; |
|
279 |
} |
|
280 |
|
|
281 |
private static Object convertMWArrayToJavaObject(MWArray mwArray) { |
|
282 |
try { |
|
283 |
if (mwArray instanceof MWNumericArray) { |
|
284 |
MWNumericArray numArray = (MWNumericArray) mwArray; |
|
285 |
if (numArray.numberOfElements() == 1) { |
|
286 |
MWClassID mwClassID = numArray.classID(); |
|
287 |
if (mwClassID.equals(MWClassID.DOUBLE)) { |
|
288 |
return numArray.getDouble(); |
|
289 |
} else if (mwClassID.equals(MWClassID.SINGLE)) { |
|
290 |
return numArray.getFloat(); |
|
291 |
} else if (mwClassID.equals(MWClassID.INT8) || mwClassID.equals(MWClassID.UINT8)) { |
|
292 |
return numArray.getByte(); |
|
293 |
} else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) { |
|
294 |
return numArray.getShort(); |
|
295 |
} else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) { |
|
296 |
return numArray.getInt(); |
|
297 |
} else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) { |
|
298 |
return numArray.getLong(); |
|
299 |
} else if (mwClassID.equals(MWClassID.LOGICAL)) { |
|
300 |
return numArray.getByte(); |
|
301 |
} |
|
302 |
return null; |
|
303 |
} else { |
|
304 |
MWClassID mwClassID = numArray.classID(); |
|
305 |
if (mwClassID.equals(MWClassID.DOUBLE)) { |
|
306 |
return numArray.toDoubleArray(); |
|
307 |
} else if (mwClassID.equals(MWClassID.SINGLE)) { |
|
308 |
return numArray.toFloatArray(); |
|
309 |
} else if (mwClassID.equals(MWClassID.INT8)) { |
|
310 |
return numArray.toByteArray(); |
|
311 |
} else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) { |
|
312 |
return numArray.toShortArray(); |
|
313 |
} else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) { |
|
314 |
return numArray.toIntArray(); |
|
315 |
} else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) { |
|
316 |
return numArray.toLongArray(); |
|
317 |
} else if (mwClassID.equals(MWClassID.LOGICAL)) { |
|
318 |
return numArray.toByteArray(); |
|
319 |
} |
|
320 |
return null; |
|
321 |
} |
|
322 |
} else if (mwArray instanceof MWCharArray) { |
|
323 |
MWCharArray stringArray = (MWCharArray) mwArray; |
|
324 |
if (stringArray.numberOfElements() == 1) { |
|
325 |
return stringArray.getChar(1); |
|
326 |
} else { |
|
327 |
// 不支持string,string都是用char[]返回,所以这里将char[]转为string |
|
328 |
int[] dimensions = stringArray.getDimensions(); |
|
329 |
if (dimensions[0] == 1) { |
|
330 |
//string |
|
331 |
return stringArray.toString(); |
|
332 |
} else { |
|
333 |
//String[] 暂时只考虑一维string |
|
334 |
char[][] array = (char[][]) stringArray.toArray(); |
|
335 |
String[] list = new String[dimensions[0]]; |
|
336 |
for (int i = 0; i < dimensions[0]; i++) { |
|
337 |
list[i] = (String.valueOf(array[i])); |
|
338 |
} |
|
339 |
return list; |
|
340 |
} |
|
341 |
} |
|
342 |
} else if (mwArray instanceof MWLogicalArray) { |
|
343 |
MWLogicalArray logicalArray = (MWLogicalArray) mwArray; |
|
344 |
if (logicalArray.numberOfElements() == 1) { |
|
345 |
return logicalArray.getBoolean(1); |
|
346 |
} else { |
|
347 |
int[] dimensions = logicalArray.getDimensions(); |
|
348 |
return logicalArray.toArray(); |
|
349 |
} |
|
350 |
} else if (mwArray instanceof MWStructArray) { |
|
351 |
MWStructArray structArray = (MWStructArray) mwArray; |
|
352 |
return convertStructToMap(structArray); |
|
353 |
} |
|
354 |
} finally { |
|
355 |
mwArray.dispose(); |
|
356 |
} |
|
357 |
return null; |
|
358 |
} |
|
359 |
|
|
360 |
public static HashMap<String, Object> handleModelSettings(List<MlModelMethodSettingDTO> modelSettings) { |
|
361 |
HashMap<String, Object> resultMap = null; |
|
362 |
try { |
|
363 |
resultMap = new HashMap<>(modelSettings.size()); |
|
364 |
for (MlModelMethodSettingDTO modelSetting : modelSettings) { |
|
365 |
switch (modelSetting.getValueType()) { |
|
366 |
case "int": |
|
367 |
resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue())); |
|
368 |
break; |
|
369 |
case "string": |
|
370 |
resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue()); |
|
371 |
break; |
|
372 |
case "decimal": |
|
373 |
resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue())); |
|
374 |
break; |
|
375 |
case "decimalArray": |
|
376 |
JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue()); |
|
377 |
double[] doubles = new double[jsonArray.size()]; |
|
378 |
for (int i = 0; i < jsonArray.size(); i++) { |
|
379 |
doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i))); |
|
380 |
} |
|
381 |
resultMap.put(modelSetting.getSettingKey(), doubles); |
|
382 |
break; |
|
383 |
} |
|
384 |
} |
|
385 |
} catch (NumberFormatException e) { |
|
386 |
throw new RuntimeException("模型参数有误,请检查!!!"); |
|
387 |
} |
|
388 |
return resultMap; |
|
389 |
} |
|
390 |
|
|
391 |
public static MWStructArray convertMapToStruct(Map<String, Object> map) { |
|
392 |
String[] fieldNames = map.keySet().toArray(new String[0]); |
|
393 |
MWStructArray struct = new MWStructArray(1, 1, fieldNames); |
|
394 |
for (String key : map.keySet()) { |
|
395 |
Object value = map.get(key); |
|
396 |
if (value instanceof Number) { |
|
397 |
struct.set(key, 1, ((Number) value).doubleValue()); |
|
398 |
} else if (value instanceof String) { |
|
399 |
struct.set(key, 1, value); |
|
400 |
} |
|
401 |
// 可以根据需要添加更多类型的处理 |
|
402 |
} |
|
403 |
return struct; |
|
404 |
} |
|
405 |
} |