package com.iailab.module.model.mpk.common.utils;

import cn.hutool.core.io.FileUtil;
import com.iail.model.IAILModel;
import com.iail.utils.RSAUtils;
import com.iailab.module.model.mpk.common.MdkConstant;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import java.io.File;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Vector;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Loads model jars into isolated {@link URLClassLoader}s, binds native libraries (dll files) to
 * those loaders, and invokes model methods reflectively with per-project caching.
 *
 * <p>NOTE(review): the native-library handling reaches into non-public JDK members
 * ({@code Runtime.load0}, {@code ClassLoader.nativeLibraries}, the library's {@code finalize} and
 * {@code name}). These are JDK-version specific; confirm the runtime JDK before upgrading it.
 */
@Slf4j
public class DllUtils {

    /** projectId -> class loader that holds that project's published jar (and bound dll). */
    private static final HashMap<String, URLClassLoader> classLoaderCache = new HashMap<>();

    /**
     * "projectId_className" -> model instance. Concurrent because {@link #run} reads and writes it
     * without holding the class lock, while the synchronized remove methods mutate it.
     */
    private static final Map<String, Object> classCache = new ConcurrentHashMap<>();

    /** "projectId_className.method(paramTypes)" -> resolved method. Concurrent for the same reason. */
    private static final Map<String, Method> classMethodCache = new ConcurrentHashMap<>();

    /**
     * Loads a native library and registers it with the class loader of {@code clazz}.
     *
     * <p>Invokes the internal {@code Runtime.load0(Class, String)} reflectively so the library is
     * associated with {@code clazz}'s loader rather than with this utility's loader.
     * Original author: dzd, 2024/9/30.
     *
     * @param clazz   a class whose class loader should own the library
     * @param dllPath path of the native library to load
     * @throws RuntimeException wrapping any reflective or loading failure
     */
    public static void loadDll(Class<?> clazz, String dllPath) {
        try {
            Method method = Runtime.class.getDeclaredMethod("load0", Class.class, String.class);
            boolean accessible = method.isAccessible();
            method.setAccessible(true);
            try {
                method.invoke(Runtime.getRuntime(), clazz, dllPath);
            } finally {
                // restore the flag even when loading fails
                method.setAccessible(accessible);
            }
            log.info("成功加载dll:{}", dllPath);
        } catch (Exception e) {
            throw new RuntimeException("加载dll异常", e);
        }
    }

    /**
     * Unloads every native library registered with {@code classLoader}.
     * Original author: dzd, 2024/9/30.
     *
     * @param classLoader loader whose native libraries are finalized
     * @throws RuntimeException wrapping any reflective failure
     */
    public static synchronized void unloadDll(URLClassLoader classLoader) {
        unloadDllName(classLoader, null);
    }

    /**
     * Unloads native libraries registered with {@code classLoader}. When {@code dllName} is null or
     * empty, every library is unloaded; otherwise only the library whose recorded name equals
     * {@code dllName}. Original author: dzd, 2024/9/30.
     *
     * @param classLoader loader whose {@code nativeLibraries} list is scanned
     * @param dllName     name to match against each library's {@code name} field, or null/empty for all
     * @throws RuntimeException wrapping any reflective failure
     */
    public static synchronized void unloadDllName(URLClassLoader classLoader, String dllName) {
        try {
            Field librariesField = ClassLoader.class.getDeclaredField("nativeLibraries");
            librariesField.setAccessible(true);
            Vector<?> libs = (Vector<?>) librariesField.get(classLoader);
            for (Object lib : libs) {
                Field nameField = lib.getClass().getDeclaredField("name");
                nameField.setAccessible(true);
                String name = (String) nameField.get(lib);
                // With a specific dllName, skip (rather than stop at) libraries that do not match.
                if (StringUtils.isNotEmpty(dllName) && !dllName.equals(name)) {
                    continue;
                }
                Method finalizeMethod = lib.getClass().getDeclaredMethod("finalize");
                boolean accessible = finalizeMethod.isAccessible();
                finalizeMethod.setAccessible(true);
                try {
                    finalizeMethod.invoke(lib);
                } finally {
                    finalizeMethod.setAccessible(accessible);
                }
                log.info("成功卸载dll:{}", name);
            }
        } catch (Exception e) {
            throw new RuntimeException("卸载dll异常", e);
        }
    }

    /**
     * Creates a new {@link URLClassLoader} over a single jar file.
     * Original author: dzd, 2024/9/30.
     *
     * @param jarPath path of the jar to load
     * @return a loader with a null parent, so it does not delegate to other loaders
     * @throws RuntimeException if the jar does not exist or its URL cannot be built
     */
    public static synchronized URLClassLoader loadJar(String jarPath) {
        File jarFile = new File(jarPath);
        if (!jarFile.exists()) {
            throw new RuntimeException("jar沒有找到!" + jarPath);
        }
        try {
            // A null parent disables parent delegation. Otherwise another loader could resolve the
            // model classes, and the dll would then be bound to that other loader.
            URLClassLoader urlClassLoader =
                    new URLClassLoader(new URL[]{jarFile.toURI().toURL()}, null, null);
            log.info("成功加载jar包:{}", jarFile.getAbsolutePath());
            return urlClassLoader;
        } catch (Exception e) {
            throw new RuntimeException("加载jar异常", e);
        }
    }

    /**
     * Closes the given class loader, releasing its jar.
     *
     * @param urlClassLoader loader to close
     * @throws RuntimeException wrapping any failure from {@link URLClassLoader#close()}
     */
    public static synchronized void unloadJar(URLClassLoader urlClassLoader) {
        try {
            urlClassLoader.close();
            log.info("成功卸载jar包。");
        } catch (Exception e) {
            throw new RuntimeException("卸载jar异常", e);
        }
    }

    /**
     * Registers {@code urlClassLoader} as the loader for {@code projectId}.
     *
     * <p>Any model instances and methods cached for this project were resolved from a previous
     * loader, so they are evicted to force re-resolution from the new one.
     *
     * @param projectId      project key
     * @param urlClassLoader loader holding the project's jar and dll
     */
    public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) {
        classLoaderCache.put(projectId, urlClassLoader);
        removeClassCache(projectId);
        removeClassMethodCache(projectId);
    }

    /**
     * Returns the loader registered for {@code projectId}.
     *
     * @param projectId project key
     * @return the registered loader, or null if none
     */
    public static synchronized URLClassLoader getClassLoader(String projectId) {
        return classLoaderCache.get(projectId);
    }

    /**
     * Removes the loader for {@code projectId}, evicts its cached instances and methods, unloads its
     * native libraries, and closes its jar. Does nothing if no loader is registered.
     *
     * <p>The cache entries are removed first, so a failure while unloading never leaves a
     * half-unloaded loader reachable. The jar is closed even if unloading the dll fails.
     *
     * @param projectId project key
     */
    public static synchronized void removeClassLoaderCache(String projectId) {
        URLClassLoader urlClassLoader = classLoaderCache.remove(projectId);
        if (urlClassLoader == null) {
            return;
        }
        removeClassCache(projectId);
        removeClassMethodCache(projectId);
        try {
            unloadDll(urlClassLoader);
        } finally {
            unloadJar(urlClassLoader);
        }
    }

    /**
     * Evicts every cached model instance that belongs to {@code projectId}.
     *
     * @param projectId project key
     */
    public static synchronized void removeClassCache(String projectId) {
        removeProjectEntries(classCache, projectId);
    }

    /**
     * Evicts every cached method that belongs to {@code projectId}.
     *
     * @param projectId project key
     */
    public static synchronized void removeClassMethodCache(String projectId) {
        removeProjectEntries(classMethodCache, projectId);
    }

    /**
     * Removes entries whose key starts with {@code projectId + "_"}, the prefix {@link #run} uses
     * when building cache keys. The separator is required so that project "1" does not evict
     * project "10".
     */
    private static void removeProjectEntries(Map<String, ?> cache, String projectId) {
        String prefix = projectId + "_";
        cache.keySet().removeIf(key -> key.startsWith(prefix));
    }

    /**
     * Deletes every file directly inside {@code bakPath} whose name starts with {@code projectId}.
     * Does nothing if {@code bakPath} is not an existing directory. Failed deletions are logged.
     *
     * @param bakPath   directory holding published artifacts
     * @param projectId file-name prefix to match
     */
    public static void removeOldFile(String bakPath, String projectId) {
        // listFiles() returns null when the path does not exist or is not a directory
        File[] files = new File(bakPath).listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            if (file.getName().startsWith(projectId) && !file.delete()) {
                log.warn("Failed to delete old file: {}", file.getAbsolutePath());
            }
        }
    }

    /**
     * Loads the published jar and dll pairs found in {@code bakPath}, typically at application
     * startup. Original author: dzd, 2024/10/10.
     *
     * <p>For each {@code <projectId><SPLIT><historyId>.jar} that has a sibling {@code .dll}, the
     * jar is loaded into a new loader and the dll is bound to that loader through
     * {@code iail.mdk.model.common.Environment}. The loader is cached only if both steps succeed.
     * Jar files whose names do not match this pattern are skipped with a warning.
     *
     * @param bakPath directory holding published artifacts
     * @throws RuntimeException if loading a jar or dll fails
     */
    public static void loadProjectPublish(String bakPath) {
        File[] files = new File(bakPath).listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (!fileName.endsWith(".jar")) {
                continue;
            }
            String[] split = fileName.substring(0, fileName.length() - 4).split(MdkConstant.SPLIT);
            if (split.length < 2) {
                log.warn("Skipping jar with unexpected name: {}", file.getAbsolutePath());
                continue;
            }
            String projectId = split[0];
            String historyId = split[1];
            String basePath = bakPath + File.separator + projectId + MdkConstant.SPLIT + historyId;
            String jarFilePath = basePath + ".jar";
            String dllFilePath = basePath + ".dll";
            if (!FileUtil.exist(jarFilePath) || !FileUtil.exist(dllFilePath)) {
                continue;
            }

            URLClassLoader urlClassLoader;
            try {
                urlClassLoader = loadJar(jarFilePath);
            } catch (Exception e) {
                throw new RuntimeException("加载jar异常", e);
            }
            try {
                loadDll(urlClassLoader.loadClass("iail.mdk.model.common.Environment"), dllFilePath);
            } catch (Exception e) {
                unloadJar(urlClassLoader);
                throw new RuntimeException("加载dll异常", e);
            }
            // cache only after both the jar and the dll loaded successfully
            addClassLoaderCache(projectId, urlClassLoader);
        }
    }

    /**
     * Invokes the model method described by {@code model}, using the project's registered loader.
     *
     * <p>The model instance and the resolved {@link Method} are cached per project and per
     * signature. The first call instantiates the class through its no-arg constructor.
     *
     * @param model            describes the class name, method name, and parameter types
     * @param paramsValueArray arguments passed to the method
     * @param projectId        project whose loader is used
     * @return the method's result cast to a {@code HashMap}
     * @throws SecurityException if the licence check does not return code 1
     * @throws RuntimeException  if {@code model} is null or no loader is registered for the project
     * @throws Exception         any reflective failure, or an exception thrown by the model
     */
    public static HashMap<String, Object> run(IAILModel model, Object[] paramsValueArray, String projectId)
            throws Exception {
        if (RSAUtils.checkLisenceBean().getCode() != 1) {
            throw new SecurityException("Lisence 不可用!");
        }
        if (model == null) {
            throw new RuntimeException("模型文件不能为空!");
        }
        String classCacheKey = projectId + "_" + model.getClassName();
        String methodParams = Arrays.stream(model.getParamsArray())
                .map(e -> e.getName())
                .collect(Collectors.joining(","));
        String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")";

        // single reads (no containsKey/get pair) so a concurrent eviction cannot yield null here
        Object instance = classCache.get(classCacheKey);
        Method method = classMethodCache.get(classMethodCacheKey);
        if (instance == null || method == null) {
            URLClassLoader classLoader = getClassLoader(projectId);
            if (classLoader == null) {
                throw new RuntimeException("dll未发布,classLoader为null");
            }
            Class<?> clazz = classLoader.loadClass(model.getClassName());
            instance = clazz.getDeclaredConstructor().newInstance();
            method = clazz.getMethod(model.getMethodName(), model.getParamsArray());
            classCache.put(classCacheKey, instance);
            classMethodCache.put(classMethodCacheKey, method);
        }
        @SuppressWarnings("unchecked")
        HashMap<String, Object> result = (HashMap<String, Object>) method.invoke(instance, paramsValueArray);
        return result;
    }
}