潘志宝
2025-03-03 142bcd3bd15e9ba6176bb2093eee22040da9bd8c
提交 | 用户 | 时间
3e61b6 1 package com.iailab.module.model.matlab.common.utils;
D 2
3 import com.alibaba.fastjson.JSON;
4 import com.alibaba.fastjson.JSONArray;
5 import com.iail.model.IAILModel;
6 import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException;
7 import com.iailab.module.model.matlab.dto.MatlabJarFileClassInfoDTO;
8 import com.iailab.module.model.matlab.dto.MatlabJarFileMethodInfoDTO;
9 import com.iailab.module.model.matlab.dto.MlModelMethodSettingDTO;
10 import com.iailab.module.model.mpk.common.MdkConstant;
11 import com.mathworks.toolbox.javabuilder.*;
12 import com.mathworks.toolbox.javabuilder.internal.MWFunctionSignature;
13 import lombok.extern.slf4j.Slf4j;
087ffc 14 import org.springframework.util.CollectionUtils;
3e61b6 15
D 16 import java.io.File;
17 import java.io.IOException;
18 import java.lang.reflect.Field;
19 import java.lang.reflect.Method;
20 import java.lang.reflect.Modifier;
21 import java.net.MalformedURLException;
22 import java.net.URL;
23 import java.net.URLClassLoader;
24 import java.util.*;
25 import java.util.jar.JarEntry;
26 import java.util.jar.JarFile;
27 import java.util.stream.Collectors;
28
29 @Slf4j
30 public class MatlabUtils {
31
32     private static HashMap<String, URLClassLoader> classLoaderCache = new HashMap<>();
33     private static HashMap<String, Object> classCache = new HashMap<>();
34     private static HashMap<String, Method> classMethodCache = new HashMap<>();
35
36     /**
37      * @description: 解析matlab jar文件
38      **/
39     public static List<MatlabJarFileClassInfoDTO> parseJarInfo(String jarFilePath, String jatName) throws IllegalityJarException {
40         List<MatlabJarFileClassInfoDTO> classInfos = new ArrayList<>();
41         //加载jar用于解析内部class method
42         URLClassLoader urlClassLoader = loadJar(null,jarFilePath);
43         try (JarFile jarFile = new JarFile(jarFilePath)) {
44             // 获取 JAR 文件中所有的条目
45             Enumeration<JarEntry> entries = jarFile.entries();
46             while (entries.hasMoreElements()) {
47                 JarEntry entry = entries.nextElement();
48                 String entryName = entry.getName();
49                 // 检查条目是否为类文件
50                 if (entryName.endsWith(".class") && !entryName.endsWith("Remote.class") && !entryName.endsWith("MCRFactory.class")) {
51                     MatlabJarFileClassInfoDTO classInfo = new MatlabJarFileClassInfoDTO();
52                     // 将类文件的路径转换为类的全限定名
53                     String className = entryName.replace('/', '.').substring(0, entryName.lastIndexOf(".class"));
54                     // 校验包名是否和文件名一致,不一致则说明修改过jar包名,判定为非法包
55                     if (!className.startsWith(jatName)) {
56                         throw new IllegalityJarException();
57                     }
58
59                     classInfo.setClassName(className);
60                     try {
61                         // 加载类
62                         Class<?> clazz = urlClassLoader.loadClass(className);
63                         // 获取该类的所有属性
64                         Field[] fields = clazz.getDeclaredFields();
65                         List<MatlabJarFileMethodInfoDTO> methodInfos = new ArrayList<>();
66                         for (Field field : fields) {
67                             int modifiers = field.getModifiers();
68                             if (Modifier.isPrivate(modifiers) && Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
69                                 try {
70                                     // 绕过访问控制检查
71                                     field.setAccessible(true);
72                                     Object value = field.get(null); // 静态字段不需要实例,传 null
73                                     if (value instanceof MWFunctionSignature) {
74                                         MatlabJarFileMethodInfoDTO methodInfo = new MatlabJarFileMethodInfoDTO();
75                                         methodInfo.setMethodName(value.getClass().getField("name").get(value).toString());
76                                         // 减1是为了排除最后面的setting
77                                         methodInfo.setDataLength((Integer) value.getClass().getField("numInputs").get(value) - 1);
78                                         methodInfo.setOutLength((Integer) value.getClass().getField("numOutputs").get(value));
79                                         methodInfos.add(methodInfo);
80                                     }
81                                 } catch (Exception e) {
82                                     log.error("get matlab method info exception,className:" + className, e);
83                                 }
84                             }
85                         }
86                         classInfo.setMethodInfos(methodInfos);
87                     } catch (Exception e) {
88                         log.error("get matlab class info exception,className:" + className, e);
89                     }
90                     classInfos.add(classInfo);
91                 }
92             }
93         } catch (IOException e) {
94             log.error("get matlab jar info exception,jarFilePath:" + jarFilePath, e);
95         } finally {
96             unloadJar(urlClassLoader);
97         }
98         return classInfos;
99     }
100
101     public static synchronized URLClassLoader loadJar(String projectId, String... jarPaths) {
102         try {
103             URL[] urls = new URL[jarPaths.length];
104             for (int i = 0; i < jarPaths.length; i++) {
105                 String jarPath = jarPaths[i];
106                 File jarFile = new File(jarPath);
107                 if (!jarFile.exists()) {
108                     throw new RuntimeException("jar沒有找到!" + jarPath);
109                 }
110                 urls[i] = new File(jarPath).toURI().toURL();
111             }
112
113             // 不能设置classloader的patent,让它使用双亲委派去找系统classloader中javabuilder.jar中的依赖。因为javabuilder.jar只能加载一次,加载多次会报资源占用
114             URLClassLoader urlClassLoader = new URLClassLoader(urls);
115             if (projectId != null) {
116                 addClassLoaderCache(projectId, urlClassLoader);
117             }
118             log.info("成功加载jar包:" + String.join(";",jarPaths));
119             return urlClassLoader;
120         } catch (Exception e) {
121             throw new RuntimeException("加载jar异常", e);
122         }
123     }
124
125     public static synchronized void unloadJar(URLClassLoader urlClassLoader) {
126         try {
127             urlClassLoader.close();
128             log.info("成功卸载jar包。");
129         } catch (Exception e) {
130             throw new RuntimeException("卸载jar异常", e);
131         }
132     }
133
134     public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) {
135         classLoaderCache.put(projectId, urlClassLoader);
136     }
137
138     public static synchronized URLClassLoader getClassLoader(String projectId) {
139         return classLoaderCache.get(projectId);
140     }
141
142     public static synchronized void removeClassLoaderCache(String projectId) {
143         if (classLoaderCache.containsKey(projectId)) {
144             URLClassLoader urlClassLoader = classLoaderCache.get(projectId);
145             unloadJar(urlClassLoader);
146             classLoaderCache.remove(projectId);
147             removeClassCache(projectId);
148             removeClassMethodCache(projectId);
149         }
150     }
151
152     public static synchronized void removeClassCache(String projectId) {
153         Iterator<String> iterator = classCache.keySet().iterator();
154         while (iterator.hasNext()) {
155             String key = iterator.next();
156             if (key.startsWith(projectId)) {
157                 iterator.remove();
158             }
159         }
160     }
161
162     public static synchronized void removeClassMethodCache(String projectId) {
163         Iterator<String> iterator = classMethodCache.keySet().iterator();
164         while (iterator.hasNext()) {
165             String key = iterator.next();
166             if (key.startsWith(projectId)) {
167                 iterator.remove();
168             }
169         }
170     }
171
172     public static void removeOldFile(String bakPath, String projectId) {
173         File dir = new File(bakPath);
174         if (dir.exists() && dir.isDirectory()) {
175             File[] files = dir.listFiles();
176             if (null != files && files.length > 0) {
177                 for (File file : files) {
178                     if (file.getName().startsWith(projectId)) {
179                         file.delete();
180                     }
181                 }
182             }
183         }
184     }
185
186     /**
087ffc 187      * @description: 项目启动加载已发布的jar
3e61b6 188      * @author: dzd
D 189      * @date: 2024/10/10 11:58
190      **/
191     public static void loadProjectPublish(String bakPath) {
192         File dir = new File(bakPath);
193         if (dir.exists() && dir.isDirectory()) {
194             File[] files = dir.listFiles();
195             if (null != files && files.length > 0) {
087ffc 196                 HashMap<String,List<String>> projectIdJarFilePaths = new HashMap<>();
3e61b6 197                 for (File file : files) {
D 198                     String fileName = file.getName();
199                     if (fileName.endsWith(".jar")) {
200                         String[] split = fileName.split(MdkConstant.SPLIT);
201                         String projectId = split[0];
087ffc 202                         if (projectId != null) {
D 203                             if (projectIdJarFilePaths.containsKey(projectId)) {
204                                 projectIdJarFilePaths.get(projectId).add(file.getAbsolutePath());
205                             } else {
206                                 projectIdJarFilePaths.put(projectId,new ArrayList<String>(){{add(file.getAbsolutePath());}});
207                             }
3e61b6 208                         }
D 209                     }
210                 }
087ffc 211
D 212                 try {
213                     if (!CollectionUtils.isEmpty(projectIdJarFilePaths)) {
214                         for (Map.Entry<String, List<String>> entry : projectIdJarFilePaths.entrySet()) {
215                             // 加载新的jar
216                             loadJar(entry.getKey(),entry.getValue().toArray(new String[0]));
217                         }
218                     }
219                 } catch (Exception e) {
220                     throw new RuntimeException("加载jar异常", e);
221                 }
3e61b6 222             }
D 223         }
224
225     }
226
227     public static HashMap<String, Object> run(IAILModel model, Object[] paramsValueArray, String projectId) throws Exception {
228         if (model == null) {
229             throw new RuntimeException("模型文件不能为空!");
230         } else {
231             // 上传时校验文件名和class的开头是相同的,保证className唯一,可以用 projectId_className 作为唯一key
232             String classCacheKey = projectId + "_" + model.getClassName();
233             String methodParams = Arrays.stream(model.getParamsArray()).map(e -> e.getName()).collect(Collectors.joining(","));
234             String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")";
235             // 因为一个类下可能有多多个方法,所以这里classCacheKey 和 classMethodCacheKey 要分开判断
236             if (classCache.containsKey(classCacheKey)) {
237                 if (classMethodCache.containsKey(classMethodCacheKey)) {
238                     return (HashMap) classMethodCache.get(classMethodCacheKey).invoke(classCache.get(classCacheKey), paramsValueArray);
239                 } else {
240                     // 运行过这个类的其他方法,类从缓存中取,新建方法
241                     Object o = classCache.get(classCacheKey);
242                     Class<?>[] paramsArray = new Class[2];
243                     paramsArray[0] = int.class;
244                     paramsArray[1] = Object[].class;
245                     Method method = o.getClass().getMethod(model.getMethodName(), paramsArray);
246                     classMethodCache.put(classMethodCacheKey, method);
247                     Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
248                     HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
249                     return map;
250                 }
251
252             } else {
253                 URLClassLoader classLoader = getClassLoader(projectId);
254                 if (null == classLoader) {
255                     throw new RuntimeException("matlab未发布,classLoader为null");
256                 }
257                 Class<?> clazz = classLoader.loadClass(model.getClassName());
258                 Object o = clazz.newInstance();
259                 Class<?>[] paramsArray = new Class[2];
260                 paramsArray[0] = int.class;
261                 paramsArray[1] = Object[].class;
262                 Method method = clazz.getMethod(model.getMethodName(), paramsArray);
263                 classCache.put(classCacheKey, o);
264                 classMethodCache.put(classMethodCacheKey, method);
265                 Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
266                 HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
267                 return map;
268             }
269         }
270     }
271
272     public static HashMap<String, Object> convertStructToMap(MWStructArray struct) {
273         HashMap<String, Object> map;
274         try {
275             map = new HashMap<>();
276             String[] fieldNames = struct.fieldNames();
277             int numElements = struct.numberOfElements();
278             for (int i = 1; i <= numElements; i++) {
279                 for (String fieldName : fieldNames) {
280                     MWArray value = struct.getField(fieldName, i);
281                     Object javaValue = convertMWArrayToJavaObject(value);
282                     map.put(fieldName, javaValue);
283                 }
284             }
285         } finally {
286             struct.dispose();
287         }
288         return map;
289     }
290
291     private static Object convertMWArrayToJavaObject(MWArray mwArray) {
292         try {
293             if (mwArray instanceof MWNumericArray) {
294                 MWNumericArray numArray = (MWNumericArray) mwArray;
295                 if (numArray.numberOfElements() == 1) {
296                     MWClassID mwClassID = numArray.classID();
297                     if (mwClassID.equals(MWClassID.DOUBLE)) {
298                         return numArray.getDouble();
299                     } else if (mwClassID.equals(MWClassID.SINGLE)) {
300                         return numArray.getFloat();
301                     } else if (mwClassID.equals(MWClassID.INT8) || mwClassID.equals(MWClassID.UINT8)) {
302                         return numArray.getByte();
303                     } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
304                         return numArray.getShort();
305                     } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
306                         return numArray.getInt();
307                     } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
308                         return numArray.getLong();
309                     } else if (mwClassID.equals(MWClassID.LOGICAL)) {
310                         return numArray.getByte();
311                     }
312                     return null;
313                 } else {
314                     MWClassID mwClassID = numArray.classID();
315                     if (mwClassID.equals(MWClassID.DOUBLE)) {
316                         return numArray.toDoubleArray();
317                     } else if (mwClassID.equals(MWClassID.SINGLE)) {
318                         return numArray.toFloatArray();
319                     } else if (mwClassID.equals(MWClassID.INT8)) {
320                         return numArray.toByteArray();
321                     } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
322                         return numArray.toShortArray();
323                     } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
324                         return numArray.toIntArray();
325                     } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
326                         return numArray.toLongArray();
327                     } else if (mwClassID.equals(MWClassID.LOGICAL)) {
328                         return numArray.toByteArray();
329                     }
330                     return null;
331                 }
332             } else if (mwArray instanceof MWCharArray) {
333                 MWCharArray stringArray = (MWCharArray) mwArray;
334                 if (stringArray.numberOfElements() == 1) {
335                     return stringArray.getChar(1);
336                 } else {
337                     // 不支持string,string都是用char[]返回,所以这里将char[]转为string
338                     int[] dimensions = stringArray.getDimensions();
339                     if (dimensions[0] == 1) {
340                         //string
341                         return stringArray.toString();
342                     } else {
343                         //String[] 暂时只考虑一维string
344                         char[][] array = (char[][]) stringArray.toArray();
345                         String[] list = new String[dimensions[0]];
346                         for (int i = 0; i < dimensions[0]; i++) {
347                             list[i] = (String.valueOf(array[i]));
348                         }
349                         return list;
350                     }
351                 }
352             } else if (mwArray instanceof MWLogicalArray) {
353                 MWLogicalArray logicalArray = (MWLogicalArray) mwArray;
354                 if (logicalArray.numberOfElements() == 1) {
355                     return logicalArray.getBoolean(1);
356                 } else {
357                     int[] dimensions = logicalArray.getDimensions();
358                     return logicalArray.toArray();
359                 }
360             } else if (mwArray instanceof MWStructArray) {
361                 MWStructArray structArray = (MWStructArray) mwArray;
362                 return convertStructToMap(structArray);
363             }
364         } finally {
365             mwArray.dispose();
366         }
367         return null;
368     }
369
370     public static HashMap<String, Object> handleModelSettings(List<MlModelMethodSettingDTO> modelSettings) {
371         HashMap<String, Object> resultMap = null;
372         try {
373             resultMap = new HashMap<>(modelSettings.size());
374             for (MlModelMethodSettingDTO modelSetting : modelSettings) {
375                 switch (modelSetting.getValueType()) {
376                     case "int":
377                         resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue()));
378                         break;
379                     case "string":
380                         resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue());
381                         break;
382                     case "decimal":
383                         resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue()));
384                         break;
385                     case "decimalArray":
386                         JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue());
387                         double[] doubles = new double[jsonArray.size()];
388                         for (int i = 0; i < jsonArray.size(); i++) {
389                             doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i)));
390                         }
391                         resultMap.put(modelSetting.getSettingKey(), doubles);
392                         break;
393                 }
394             }
395         } catch (NumberFormatException e) {
396             throw new RuntimeException("模型参数有误,请检查!!!");
397         }
398         return resultMap;
399     }
400
401     public static MWStructArray convertMapToStruct(Map<String, Object> map) {
402         String[] fieldNames = map.keySet().toArray(new String[0]);
403         MWStructArray struct = new MWStructArray(1, 1, fieldNames);
404         for (String key : map.keySet()) {
405             Object value = map.get(key);
406             if (value instanceof Number) {
407                 struct.set(key, 1, ((Number) value).doubleValue());
408             } else if (value instanceof String) {
409                 struct.set(key, 1, value);
410             }
411             // 可以根据需要添加更多类型的处理
412         }
413         return struct;
414     }
415 }