package com.iailab.module.model.matlab.common.utils; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.iail.model.IAILModel; import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException; import com.iailab.module.model.matlab.dto.MatlabJarFileClassInfoDTO; import com.iailab.module.model.matlab.dto.MatlabJarFileMethodInfoDTO; import com.iailab.module.model.matlab.dto.MlModelMethodSettingDTO; import com.iailab.module.model.mpk.common.MdkConstant; import com.mathworks.toolbox.javabuilder.*; import com.mathworks.toolbox.javabuilder.internal.MWFunctionSignature; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; import java.io.File; import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.util.*; import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.stream.Collectors; @Slf4j public class MatlabUtils { private static HashMap classLoaderCache = new HashMap<>(); private static HashMap classCache = new HashMap<>(); private static HashMap classMethodCache = new HashMap<>(); /** * @description: 解析matlab jar文件 **/ public static List parseJarInfo(String jarFilePath, String jatName) throws IllegalityJarException { List classInfos = new ArrayList<>(); //加载jar用于解析内部class method URLClassLoader urlClassLoader = loadJar(null,jarFilePath); try (JarFile jarFile = new JarFile(jarFilePath)) { // 获取 JAR 文件中所有的条目 Enumeration entries = jarFile.entries(); while (entries.hasMoreElements()) { JarEntry entry = entries.nextElement(); String entryName = entry.getName(); // 检查条目是否为类文件 if (entryName.endsWith(".class") && !entryName.endsWith("Remote.class") && !entryName.endsWith("MCRFactory.class")) { MatlabJarFileClassInfoDTO classInfo = new MatlabJarFileClassInfoDTO(); // 将类文件的路径转换为类的全限定名 String className = entryName.replace('/', '.').substring(0, entryName.lastIndexOf(".class")); // 校验包名是否和文件名一致,不一致则说明修改过jar包名,判定为非法包 if (!className.startsWith(jatName)) { throw new IllegalityJarException(); } classInfo.setClassName(className); try { // 加载类 Class clazz = urlClassLoader.loadClass(className); // 获取该类的所有属性 Field[] fields = clazz.getDeclaredFields(); List methodInfos = new ArrayList<>(); for (Field field : fields) { int modifiers = field.getModifiers(); if (Modifier.isPrivate(modifiers) && Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) { try { // 绕过访问控制检查 field.setAccessible(true); Object value = field.get(null); // 静态字段不需要实例,传 null if (value instanceof MWFunctionSignature) { MatlabJarFileMethodInfoDTO methodInfo = new MatlabJarFileMethodInfoDTO(); methodInfo.setMethodName(value.getClass().getField("name").get(value).toString()); // 减1是为了排除最后面的setting methodInfo.setDataLength((Integer) value.getClass().getField("numInputs").get(value) - 1); methodInfo.setOutLength((Integer) value.getClass().getField("numOutputs").get(value)); methodInfos.add(methodInfo); } } catch (Exception e) { log.error("get matlab method info exception,className:" + className, e); } } } classInfo.setMethodInfos(methodInfos); } catch (Exception e) { log.error("get matlab class info exception,className:" + className, e); } classInfos.add(classInfo); } } } catch (IOException e) { log.error("get matlab jar info exception,jarFilePath:" + jarFilePath, e); } finally { unloadJar(urlClassLoader); } return classInfos; } public static synchronized URLClassLoader loadJar(String projectId, String... jarPaths) { try { URL[] urls = new URL[jarPaths.length]; for (int i = 0; i < jarPaths.length; i++) { String jarPath = jarPaths[i]; File jarFile = new File(jarPath); if (!jarFile.exists()) { throw new RuntimeException("jar沒有找到!" + jarPath); } urls[i] = new File(jarPath).toURI().toURL(); } // 不能设置classloader的patent,让它使用双亲委派去找系统classloader中javabuilder.jar中的依赖。因为javabuilder.jar只能加载一次,加载多次会报资源占用 URLClassLoader urlClassLoader = new URLClassLoader(urls); if (projectId != null) { addClassLoaderCache(projectId, urlClassLoader); } log.info("成功加载jar包:" + String.join(";",jarPaths)); return urlClassLoader; } catch (Exception e) { throw new RuntimeException("加载jar异常", e); } } public static synchronized void unloadJar(URLClassLoader urlClassLoader) { try { urlClassLoader.close(); log.info("成功卸载jar包。"); } catch (Exception e) { throw new RuntimeException("卸载jar异常", e); } } public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) { classLoaderCache.put(projectId, urlClassLoader); } public static synchronized URLClassLoader getClassLoader(String projectId) { return classLoaderCache.get(projectId); } public static synchronized void removeClassLoaderCache(String projectId) { if (classLoaderCache.containsKey(projectId)) { URLClassLoader urlClassLoader = classLoaderCache.get(projectId); unloadJar(urlClassLoader); classLoaderCache.remove(projectId); removeClassCache(projectId); removeClassMethodCache(projectId); } } public static synchronized void removeClassCache(String projectId) { Iterator iterator = classCache.keySet().iterator(); while (iterator.hasNext()) { String key = iterator.next(); if (key.startsWith(projectId)) { iterator.remove(); } } } public static synchronized void removeClassMethodCache(String projectId) { Iterator iterator = classMethodCache.keySet().iterator(); while (iterator.hasNext()) { String key = iterator.next(); if (key.startsWith(projectId)) { iterator.remove(); } } } public static void removeOldFile(String bakPath, String projectId) { File dir = new File(bakPath); if (dir.exists() && dir.isDirectory()) { File[] files = dir.listFiles(); if (null != files && files.length > 0) { for (File file : files) { if (file.getName().startsWith(projectId)) { file.delete(); } } } } } /** * @description: 项目启动加载已发布的jar * @author: dzd * @date: 2024/10/10 11:58 **/ public static void loadProjectPublish(String bakPath) { File dir = new File(bakPath); if (dir.exists() && dir.isDirectory()) { File[] files = dir.listFiles(); if (null != files && files.length > 0) { HashMap> projectIdJarFilePaths = new HashMap<>(); for (File file : files) { String fileName = file.getName(); if (fileName.endsWith(".jar")) { String[] split = fileName.split(MdkConstant.SPLIT); String projectId = split[0]; if (projectId != null) { if (projectIdJarFilePaths.containsKey(projectId)) { projectIdJarFilePaths.get(projectId).add(file.getAbsolutePath()); } else { projectIdJarFilePaths.put(projectId,new ArrayList(){{add(file.getAbsolutePath());}}); } } } } try { if (!CollectionUtils.isEmpty(projectIdJarFilePaths)) { for (Map.Entry> entry : projectIdJarFilePaths.entrySet()) { // 加载新的jar loadJar(entry.getKey(),entry.getValue().toArray(new String[0])); } } } catch (Exception e) { throw new RuntimeException("加载jar异常", e); } } } } public static HashMap run(IAILModel model, Object[] paramsValueArray, String projectId) throws Exception { if (model == null) { throw new RuntimeException("模型文件不能为空!"); } else { // 上传时校验文件名和class的开头是相同的,保证className唯一,可以用 projectId_className 作为唯一key String classCacheKey = projectId + "_" + model.getClassName(); String methodParams = Arrays.stream(model.getParamsArray()).map(e -> e.getName()).collect(Collectors.joining(",")); String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")"; // 因为一个类下可能有多多个方法,所以这里classCacheKey 和 classMethodCacheKey 要分开判断 if (classCache.containsKey(classCacheKey)) { if (classMethodCache.containsKey(classMethodCacheKey)) { return (HashMap) classMethodCache.get(classMethodCacheKey).invoke(classCache.get(classCacheKey), paramsValueArray); } else { // 运行过这个类的其他方法,类从缓存中取,新建方法 Object o = classCache.get(classCacheKey); Class[] paramsArray = new Class[2]; paramsArray[0] = int.class; paramsArray[1] = Object[].class; Method method = o.getClass().getMethod(model.getMethodName(), paramsArray); classMethodCache.put(classMethodCacheKey, method); Object[] objects = (Object[]) method.invoke(o, paramsValueArray); HashMap map = convertStructToMap((MWStructArray) objects[0]); return map; } } else { URLClassLoader classLoader = getClassLoader(projectId); if (null == classLoader) { throw new RuntimeException("matlab未发布,classLoader为null"); } Class clazz = classLoader.loadClass(model.getClassName()); Object o = clazz.newInstance(); Class[] paramsArray = new Class[2]; paramsArray[0] = int.class; paramsArray[1] = Object[].class; Method method = clazz.getMethod(model.getMethodName(), paramsArray); classCache.put(classCacheKey, o); classMethodCache.put(classMethodCacheKey, method); Object[] objects = (Object[]) method.invoke(o, paramsValueArray); HashMap map = convertStructToMap((MWStructArray) objects[0]); return map; } } } public static HashMap convertStructToMap(MWStructArray struct) { HashMap map; try { map = new HashMap<>(); String[] fieldNames = struct.fieldNames(); int numElements = struct.numberOfElements(); for (int i = 1; i <= numElements; i++) { for (String fieldName : fieldNames) { MWArray value = struct.getField(fieldName, i); Object javaValue = convertMWArrayToJavaObject(value); map.put(fieldName, javaValue); } } } finally { struct.dispose(); } return map; } private static Object convertMWArrayToJavaObject(MWArray mwArray) { try { if (mwArray instanceof MWNumericArray) { MWNumericArray numArray = (MWNumericArray) mwArray; if (numArray.numberOfElements() == 1) { MWClassID mwClassID = numArray.classID(); if (mwClassID.equals(MWClassID.DOUBLE)) { return numArray.getDouble(); } else if (mwClassID.equals(MWClassID.SINGLE)) { return numArray.getFloat(); } else if (mwClassID.equals(MWClassID.INT8) || mwClassID.equals(MWClassID.UINT8)) { return numArray.getByte(); } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) { return numArray.getShort(); } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) { return numArray.getInt(); } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) { return numArray.getLong(); } else if (mwClassID.equals(MWClassID.LOGICAL)) { return numArray.getByte(); } return null; } else { MWClassID mwClassID = numArray.classID(); if (mwClassID.equals(MWClassID.DOUBLE)) { return numArray.toDoubleArray(); } else if (mwClassID.equals(MWClassID.SINGLE)) { return numArray.toFloatArray(); } else if (mwClassID.equals(MWClassID.INT8)) { return numArray.toByteArray(); } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) { return numArray.toShortArray(); } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) { return numArray.toIntArray(); } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) { return numArray.toLongArray(); } else if (mwClassID.equals(MWClassID.LOGICAL)) { return numArray.toByteArray(); } return null; } } else if (mwArray instanceof MWCharArray) { MWCharArray stringArray = (MWCharArray) mwArray; if (stringArray.numberOfElements() == 1) { return stringArray.getChar(1); } else { // 不支持string,string都是用char[]返回,所以这里将char[]转为string int[] dimensions = stringArray.getDimensions(); if (dimensions[0] == 1) { //string return stringArray.toString(); } else { //String[] 暂时只考虑一维string char[][] array = (char[][]) stringArray.toArray(); String[] list = new String[dimensions[0]]; for (int i = 0; i < dimensions[0]; i++) { list[i] = (String.valueOf(array[i])); } return list; } } } else if (mwArray instanceof MWLogicalArray) { MWLogicalArray logicalArray = (MWLogicalArray) mwArray; if (logicalArray.numberOfElements() == 1) { return logicalArray.getBoolean(1); } else { int[] dimensions = logicalArray.getDimensions(); return logicalArray.toArray(); } } else if (mwArray instanceof MWStructArray) { MWStructArray structArray = (MWStructArray) mwArray; return convertStructToMap(structArray); } } finally { mwArray.dispose(); } return null; } public static HashMap handleModelSettings(List modelSettings) { HashMap resultMap = null; try { resultMap = new HashMap<>(modelSettings.size()); for (MlModelMethodSettingDTO modelSetting : modelSettings) { switch (modelSetting.getValueType()) { case "int": resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue())); break; case "string": resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue()); break; case "decimal": resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue())); break; case "decimalArray": JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue()); double[] doubles = new double[jsonArray.size()]; for (int i = 0; i < jsonArray.size(); i++) { doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i))); } resultMap.put(modelSetting.getSettingKey(), doubles); break; } } } catch (NumberFormatException e) { throw new RuntimeException("模型参数有误,请检查!!!"); } return resultMap; } public static MWStructArray convertMapToStruct(Map map) { String[] fieldNames = map.keySet().toArray(new String[0]); MWStructArray struct = new MWStructArray(1, 1, fieldNames); for (String key : map.keySet()) { Object value = map.get(key); if (value instanceof Number) { struct.set(key, 1, ((Number) value).doubleValue()); } else if (value instanceof String) { struct.set(key, 1, value); } // 可以根据需要添加更多类型的处理 } return struct; } }