dengzedong
2025-02-27 3e61b6d86d6a98214e56c652a36a2290d471a695
提交 | 用户 | 时间
3e61b6 1 package com.iailab.module.model.matlab.common.utils;
D 2
3 import com.alibaba.fastjson.JSON;
4 import com.alibaba.fastjson.JSONArray;
5 import com.iail.model.IAILModel;
6 import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException;
7 import com.iailab.module.model.matlab.dto.MatlabJarFileClassInfoDTO;
8 import com.iailab.module.model.matlab.dto.MatlabJarFileMethodInfoDTO;
9 import com.iailab.module.model.matlab.dto.MlModelMethodSettingDTO;
10 import com.iailab.module.model.mpk.common.MdkConstant;
11 import com.mathworks.toolbox.javabuilder.*;
12 import com.mathworks.toolbox.javabuilder.internal.MWFunctionSignature;
13 import lombok.extern.slf4j.Slf4j;
14
15 import java.io.File;
16 import java.io.IOException;
17 import java.lang.reflect.Field;
18 import java.lang.reflect.Method;
19 import java.lang.reflect.Modifier;
20 import java.net.MalformedURLException;
21 import java.net.URL;
22 import java.net.URLClassLoader;
23 import java.util.*;
24 import java.util.jar.JarEntry;
25 import java.util.jar.JarFile;
26 import java.util.stream.Collectors;
27
28 @Slf4j
29 public class MatlabUtils {
30
31     private static HashMap<String, URLClassLoader> classLoaderCache = new HashMap<>();
32     private static HashMap<String, Object> classCache = new HashMap<>();
33     private static HashMap<String, Method> classMethodCache = new HashMap<>();
34
35     /**
36      * @description: 解析matlab jar文件
37      **/
38     public static List<MatlabJarFileClassInfoDTO> parseJarInfo(String jarFilePath, String jatName) throws IllegalityJarException {
39         List<MatlabJarFileClassInfoDTO> classInfos = new ArrayList<>();
40         //加载jar用于解析内部class method
41         URLClassLoader urlClassLoader = loadJar(null,jarFilePath);
42         try (JarFile jarFile = new JarFile(jarFilePath)) {
43             // 获取 JAR 文件中所有的条目
44             Enumeration<JarEntry> entries = jarFile.entries();
45             while (entries.hasMoreElements()) {
46                 JarEntry entry = entries.nextElement();
47                 String entryName = entry.getName();
48                 // 检查条目是否为类文件
49                 if (entryName.endsWith(".class") && !entryName.endsWith("Remote.class") && !entryName.endsWith("MCRFactory.class")) {
50                     MatlabJarFileClassInfoDTO classInfo = new MatlabJarFileClassInfoDTO();
51                     // 将类文件的路径转换为类的全限定名
52                     String className = entryName.replace('/', '.').substring(0, entryName.lastIndexOf(".class"));
53                     // 校验包名是否和文件名一致,不一致则说明修改过jar包名,判定为非法包
54                     if (!className.startsWith(jatName)) {
55                         throw new IllegalityJarException();
56                     }
57
58                     classInfo.setClassName(className);
59                     try {
60                         // 加载类
61                         Class<?> clazz = urlClassLoader.loadClass(className);
62                         // 获取该类的所有属性
63                         Field[] fields = clazz.getDeclaredFields();
64                         List<MatlabJarFileMethodInfoDTO> methodInfos = new ArrayList<>();
65                         for (Field field : fields) {
66                             int modifiers = field.getModifiers();
67                             if (Modifier.isPrivate(modifiers) && Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
68                                 try {
69                                     // 绕过访问控制检查
70                                     field.setAccessible(true);
71                                     Object value = field.get(null); // 静态字段不需要实例,传 null
72                                     if (value instanceof MWFunctionSignature) {
73                                         MatlabJarFileMethodInfoDTO methodInfo = new MatlabJarFileMethodInfoDTO();
74                                         methodInfo.setMethodName(value.getClass().getField("name").get(value).toString());
75                                         // 减1是为了排除最后面的setting
76                                         methodInfo.setDataLength((Integer) value.getClass().getField("numInputs").get(value) - 1);
77                                         methodInfo.setOutLength((Integer) value.getClass().getField("numOutputs").get(value));
78                                         methodInfos.add(methodInfo);
79                                     }
80                                 } catch (Exception e) {
81                                     log.error("get matlab method info exception,className:" + className, e);
82                                 }
83                             }
84                         }
85                         classInfo.setMethodInfos(methodInfos);
86                     } catch (Exception e) {
87                         log.error("get matlab class info exception,className:" + className, e);
88                     }
89                     classInfos.add(classInfo);
90                 }
91             }
92         } catch (IOException e) {
93             log.error("get matlab jar info exception,jarFilePath:" + jarFilePath, e);
94         } finally {
95             unloadJar(urlClassLoader);
96         }
97         return classInfos;
98     }
99
100     public static synchronized URLClassLoader loadJar(String projectId, String... jarPaths) {
101         try {
102             URL[] urls = new URL[jarPaths.length];
103             for (int i = 0; i < jarPaths.length; i++) {
104                 String jarPath = jarPaths[i];
105                 File jarFile = new File(jarPath);
106                 if (!jarFile.exists()) {
107                     throw new RuntimeException("jar沒有找到!" + jarPath);
108                 }
109                 urls[i] = new File(jarPath).toURI().toURL();
110             }
111
112             // 不能设置classloader的patent,让它使用双亲委派去找系统classloader中javabuilder.jar中的依赖。因为javabuilder.jar只能加载一次,加载多次会报资源占用
113             URLClassLoader urlClassLoader = new URLClassLoader(urls);
114             if (projectId != null) {
115                 addClassLoaderCache(projectId, urlClassLoader);
116             }
117             log.info("成功加载jar包:" + String.join(";",jarPaths));
118             return urlClassLoader;
119         } catch (Exception e) {
120             throw new RuntimeException("加载jar异常", e);
121         }
122     }
123
124     public static synchronized void unloadJar(URLClassLoader urlClassLoader) {
125         try {
126             urlClassLoader.close();
127             log.info("成功卸载jar包。");
128         } catch (Exception e) {
129             throw new RuntimeException("卸载jar异常", e);
130         }
131     }
132
133     public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) {
134         classLoaderCache.put(projectId, urlClassLoader);
135     }
136
137     public static synchronized URLClassLoader getClassLoader(String projectId) {
138         return classLoaderCache.get(projectId);
139     }
140
141     public static synchronized void removeClassLoaderCache(String projectId) {
142         if (classLoaderCache.containsKey(projectId)) {
143             URLClassLoader urlClassLoader = classLoaderCache.get(projectId);
144             unloadJar(urlClassLoader);
145             classLoaderCache.remove(projectId);
146             removeClassCache(projectId);
147             removeClassMethodCache(projectId);
148         }
149     }
150
151     public static synchronized void removeClassCache(String projectId) {
152         Iterator<String> iterator = classCache.keySet().iterator();
153         while (iterator.hasNext()) {
154             String key = iterator.next();
155             if (key.startsWith(projectId)) {
156                 iterator.remove();
157             }
158         }
159     }
160
161     public static synchronized void removeClassMethodCache(String projectId) {
162         Iterator<String> iterator = classMethodCache.keySet().iterator();
163         while (iterator.hasNext()) {
164             String key = iterator.next();
165             if (key.startsWith(projectId)) {
166                 iterator.remove();
167             }
168         }
169     }
170
171     public static void removeOldFile(String bakPath, String projectId) {
172         File dir = new File(bakPath);
173         if (dir.exists() && dir.isDirectory()) {
174             File[] files = dir.listFiles();
175             if (null != files && files.length > 0) {
176                 for (File file : files) {
177                     if (file.getName().startsWith(projectId)) {
178                         file.delete();
179                     }
180                 }
181             }
182         }
183     }
184
185     /**
186      * @description: 项目启动加载已发布的dll和jar
187      * @author: dzd
188      * @date: 2024/10/10 11:58
189      **/
190     public static void loadProjectPublish(String bakPath) {
191         File dir = new File(bakPath);
192         if (dir.exists() && dir.isDirectory()) {
193             File[] files = dir.listFiles();
194             if (null != files && files.length > 0) {
195                 for (File file : files) {
196                     String fileName = file.getName();
197                     if (fileName.endsWith(".jar")) {
198                         String[] split = fileName.split(MdkConstant.SPLIT);
199                         String projectId = split[0];
200
201                         URLClassLoader urlClassLoader = null;
202                         try {
203                             // 加载新的jar
204                             urlClassLoader = loadJar(projectId,file.getAbsolutePath());
205                         } catch (Exception e) {
206                             throw new RuntimeException("加载jar异常", e);
207                         }
208                         // 成功后加入缓存
209                         addClassLoaderCache(projectId, urlClassLoader);
210                     }
211                 }
212             }
213         }
214
215     }
216
217     public static HashMap<String, Object> run(IAILModel model, Object[] paramsValueArray, String projectId) throws Exception {
218         if (model == null) {
219             throw new RuntimeException("模型文件不能为空!");
220         } else {
221             // 上传时校验文件名和class的开头是相同的,保证className唯一,可以用 projectId_className 作为唯一key
222             String classCacheKey = projectId + "_" + model.getClassName();
223             String methodParams = Arrays.stream(model.getParamsArray()).map(e -> e.getName()).collect(Collectors.joining(","));
224             String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")";
225             // 因为一个类下可能有多多个方法,所以这里classCacheKey 和 classMethodCacheKey 要分开判断
226             if (classCache.containsKey(classCacheKey)) {
227                 if (classMethodCache.containsKey(classMethodCacheKey)) {
228                     return (HashMap) classMethodCache.get(classMethodCacheKey).invoke(classCache.get(classCacheKey), paramsValueArray);
229                 } else {
230                     // 运行过这个类的其他方法,类从缓存中取,新建方法
231                     Object o = classCache.get(classCacheKey);
232                     Class<?>[] paramsArray = new Class[2];
233                     paramsArray[0] = int.class;
234                     paramsArray[1] = Object[].class;
235                     Method method = o.getClass().getMethod(model.getMethodName(), paramsArray);
236                     classMethodCache.put(classMethodCacheKey, method);
237                     Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
238                     HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
239                     return map;
240                 }
241
242             } else {
243                 URLClassLoader classLoader = getClassLoader(projectId);
244                 if (null == classLoader) {
245                     throw new RuntimeException("matlab未发布,classLoader为null");
246                 }
247                 Class<?> clazz = classLoader.loadClass(model.getClassName());
248                 Object o = clazz.newInstance();
249                 Class<?>[] paramsArray = new Class[2];
250                 paramsArray[0] = int.class;
251                 paramsArray[1] = Object[].class;
252                 Method method = clazz.getMethod(model.getMethodName(), paramsArray);
253                 classCache.put(classCacheKey, o);
254                 classMethodCache.put(classMethodCacheKey, method);
255                 Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
256                 HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
257                 return map;
258             }
259         }
260     }
261
262     public static HashMap<String, Object> convertStructToMap(MWStructArray struct) {
263         HashMap<String, Object> map;
264         try {
265             map = new HashMap<>();
266             String[] fieldNames = struct.fieldNames();
267             int numElements = struct.numberOfElements();
268             for (int i = 1; i <= numElements; i++) {
269                 for (String fieldName : fieldNames) {
270                     MWArray value = struct.getField(fieldName, i);
271                     Object javaValue = convertMWArrayToJavaObject(value);
272                     map.put(fieldName, javaValue);
273                 }
274             }
275         } finally {
276             struct.dispose();
277         }
278         return map;
279     }
280
281     private static Object convertMWArrayToJavaObject(MWArray mwArray) {
282         try {
283             if (mwArray instanceof MWNumericArray) {
284                 MWNumericArray numArray = (MWNumericArray) mwArray;
285                 if (numArray.numberOfElements() == 1) {
286                     MWClassID mwClassID = numArray.classID();
287                     if (mwClassID.equals(MWClassID.DOUBLE)) {
288                         return numArray.getDouble();
289                     } else if (mwClassID.equals(MWClassID.SINGLE)) {
290                         return numArray.getFloat();
291                     } else if (mwClassID.equals(MWClassID.INT8) || mwClassID.equals(MWClassID.UINT8)) {
292                         return numArray.getByte();
293                     } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
294                         return numArray.getShort();
295                     } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
296                         return numArray.getInt();
297                     } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
298                         return numArray.getLong();
299                     } else if (mwClassID.equals(MWClassID.LOGICAL)) {
300                         return numArray.getByte();
301                     }
302                     return null;
303                 } else {
304                     MWClassID mwClassID = numArray.classID();
305                     if (mwClassID.equals(MWClassID.DOUBLE)) {
306                         return numArray.toDoubleArray();
307                     } else if (mwClassID.equals(MWClassID.SINGLE)) {
308                         return numArray.toFloatArray();
309                     } else if (mwClassID.equals(MWClassID.INT8)) {
310                         return numArray.toByteArray();
311                     } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
312                         return numArray.toShortArray();
313                     } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
314                         return numArray.toIntArray();
315                     } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
316                         return numArray.toLongArray();
317                     } else if (mwClassID.equals(MWClassID.LOGICAL)) {
318                         return numArray.toByteArray();
319                     }
320                     return null;
321                 }
322             } else if (mwArray instanceof MWCharArray) {
323                 MWCharArray stringArray = (MWCharArray) mwArray;
324                 if (stringArray.numberOfElements() == 1) {
325                     return stringArray.getChar(1);
326                 } else {
327                     // 不支持string,string都是用char[]返回,所以这里将char[]转为string
328                     int[] dimensions = stringArray.getDimensions();
329                     if (dimensions[0] == 1) {
330                         //string
331                         return stringArray.toString();
332                     } else {
333                         //String[] 暂时只考虑一维string
334                         char[][] array = (char[][]) stringArray.toArray();
335                         String[] list = new String[dimensions[0]];
336                         for (int i = 0; i < dimensions[0]; i++) {
337                             list[i] = (String.valueOf(array[i]));
338                         }
339                         return list;
340                     }
341                 }
342             } else if (mwArray instanceof MWLogicalArray) {
343                 MWLogicalArray logicalArray = (MWLogicalArray) mwArray;
344                 if (logicalArray.numberOfElements() == 1) {
345                     return logicalArray.getBoolean(1);
346                 } else {
347                     int[] dimensions = logicalArray.getDimensions();
348                     return logicalArray.toArray();
349                 }
350             } else if (mwArray instanceof MWStructArray) {
351                 MWStructArray structArray = (MWStructArray) mwArray;
352                 return convertStructToMap(structArray);
353             }
354         } finally {
355             mwArray.dispose();
356         }
357         return null;
358     }
359
360     public static HashMap<String, Object> handleModelSettings(List<MlModelMethodSettingDTO> modelSettings) {
361         HashMap<String, Object> resultMap = null;
362         try {
363             resultMap = new HashMap<>(modelSettings.size());
364             for (MlModelMethodSettingDTO modelSetting : modelSettings) {
365                 switch (modelSetting.getValueType()) {
366                     case "int":
367                         resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue()));
368                         break;
369                     case "string":
370                         resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue());
371                         break;
372                     case "decimal":
373                         resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue()));
374                         break;
375                     case "decimalArray":
376                         JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue());
377                         double[] doubles = new double[jsonArray.size()];
378                         for (int i = 0; i < jsonArray.size(); i++) {
379                             doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i)));
380                         }
381                         resultMap.put(modelSetting.getSettingKey(), doubles);
382                         break;
383                 }
384             }
385         } catch (NumberFormatException e) {
386             throw new RuntimeException("模型参数有误,请检查!!!");
387         }
388         return resultMap;
389     }
390
391     public static MWStructArray convertMapToStruct(Map<String, Object> map) {
392         String[] fieldNames = map.keySet().toArray(new String[0]);
393         MWStructArray struct = new MWStructArray(1, 1, fieldNames);
394         for (String key : map.keySet()) {
395             Object value = map.get(key);
396             if (value instanceof Number) {
397                 struct.set(key, 1, ((Number) value).doubleValue());
398             } else if (value instanceof String) {
399                 struct.set(key, 1, value);
400             }
401             // 可以根据需要添加更多类型的处理
402         }
403         return struct;
404     }
405 }