package com.iailab.module.model.matlab.common.utils;
|
|
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.iail.model.IAILModel;
import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException;
import com.iailab.module.model.matlab.dto.MatlabJarFileClassInfoDTO;
import com.iailab.module.model.matlab.dto.MatlabJarFileMethodInfoDTO;
import com.iailab.module.model.matlab.dto.MlModelMethodSettingDTO;
import com.iailab.module.model.mpk.common.MdkConstant;
import com.mathworks.toolbox.javabuilder.*;
import com.mathworks.toolbox.javabuilder.internal.MWFunctionSignature;
import lombok.extern.slf4j.Slf4j;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
|
|
@Slf4j
|
public class MatlabUtils {
|
|
private static HashMap<String, URLClassLoader> classLoaderCache = new HashMap<>();
|
private static HashMap<String, Object> classCache = new HashMap<>();
|
private static HashMap<String, Method> classMethodCache = new HashMap<>();
|
|
/**
|
* @description: 解析matlab jar文件
|
**/
|
public static List<MatlabJarFileClassInfoDTO> parseJarInfo(String jarFilePath, String jatName) throws IllegalityJarException {
|
List<MatlabJarFileClassInfoDTO> classInfos = new ArrayList<>();
|
//加载jar用于解析内部class method
|
URLClassLoader urlClassLoader = loadJar(null,jarFilePath);
|
try (JarFile jarFile = new JarFile(jarFilePath)) {
|
// 获取 JAR 文件中所有的条目
|
Enumeration<JarEntry> entries = jarFile.entries();
|
while (entries.hasMoreElements()) {
|
JarEntry entry = entries.nextElement();
|
String entryName = entry.getName();
|
// 检查条目是否为类文件
|
if (entryName.endsWith(".class") && !entryName.endsWith("Remote.class") && !entryName.endsWith("MCRFactory.class")) {
|
MatlabJarFileClassInfoDTO classInfo = new MatlabJarFileClassInfoDTO();
|
// 将类文件的路径转换为类的全限定名
|
String className = entryName.replace('/', '.').substring(0, entryName.lastIndexOf(".class"));
|
// 校验包名是否和文件名一致,不一致则说明修改过jar包名,判定为非法包
|
if (!className.startsWith(jatName)) {
|
throw new IllegalityJarException();
|
}
|
|
classInfo.setClassName(className);
|
try {
|
// 加载类
|
Class<?> clazz = urlClassLoader.loadClass(className);
|
// 获取该类的所有属性
|
Field[] fields = clazz.getDeclaredFields();
|
List<MatlabJarFileMethodInfoDTO> methodInfos = new ArrayList<>();
|
for (Field field : fields) {
|
int modifiers = field.getModifiers();
|
if (Modifier.isPrivate(modifiers) && Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
|
try {
|
// 绕过访问控制检查
|
field.setAccessible(true);
|
Object value = field.get(null); // 静态字段不需要实例,传 null
|
if (value instanceof MWFunctionSignature) {
|
MatlabJarFileMethodInfoDTO methodInfo = new MatlabJarFileMethodInfoDTO();
|
methodInfo.setMethodName(value.getClass().getField("name").get(value).toString());
|
// 减1是为了排除最后面的setting
|
methodInfo.setDataLength((Integer) value.getClass().getField("numInputs").get(value) - 1);
|
methodInfo.setOutLength((Integer) value.getClass().getField("numOutputs").get(value));
|
methodInfos.add(methodInfo);
|
}
|
} catch (Exception e) {
|
log.error("get matlab method info exception,className:" + className, e);
|
}
|
}
|
}
|
classInfo.setMethodInfos(methodInfos);
|
} catch (Exception e) {
|
log.error("get matlab class info exception,className:" + className, e);
|
}
|
classInfos.add(classInfo);
|
}
|
}
|
} catch (IOException e) {
|
log.error("get matlab jar info exception,jarFilePath:" + jarFilePath, e);
|
} finally {
|
unloadJar(urlClassLoader);
|
}
|
return classInfos;
|
}
|
|
public static synchronized URLClassLoader loadJar(String projectId, String... jarPaths) {
|
try {
|
URL[] urls = new URL[jarPaths.length];
|
for (int i = 0; i < jarPaths.length; i++) {
|
String jarPath = jarPaths[i];
|
File jarFile = new File(jarPath);
|
if (!jarFile.exists()) {
|
throw new RuntimeException("jar沒有找到!" + jarPath);
|
}
|
urls[i] = new File(jarPath).toURI().toURL();
|
}
|
|
// 不能设置classloader的patent,让它使用双亲委派去找系统classloader中javabuilder.jar中的依赖。因为javabuilder.jar只能加载一次,加载多次会报资源占用
|
URLClassLoader urlClassLoader = new URLClassLoader(urls);
|
if (projectId != null) {
|
addClassLoaderCache(projectId, urlClassLoader);
|
}
|
log.info("成功加载jar包:" + String.join(";",jarPaths));
|
return urlClassLoader;
|
} catch (Exception e) {
|
throw new RuntimeException("加载jar异常", e);
|
}
|
}
|
|
public static synchronized void unloadJar(URLClassLoader urlClassLoader) {
|
try {
|
urlClassLoader.close();
|
log.info("成功卸载jar包。");
|
} catch (Exception e) {
|
throw new RuntimeException("卸载jar异常", e);
|
}
|
}
|
|
public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) {
|
classLoaderCache.put(projectId, urlClassLoader);
|
}
|
|
public static synchronized URLClassLoader getClassLoader(String projectId) {
|
return classLoaderCache.get(projectId);
|
}
|
|
public static synchronized void removeClassLoaderCache(String projectId) {
|
if (classLoaderCache.containsKey(projectId)) {
|
URLClassLoader urlClassLoader = classLoaderCache.get(projectId);
|
unloadJar(urlClassLoader);
|
classLoaderCache.remove(projectId);
|
removeClassCache(projectId);
|
removeClassMethodCache(projectId);
|
}
|
}
|
|
public static synchronized void removeClassCache(String projectId) {
|
Iterator<String> iterator = classCache.keySet().iterator();
|
while (iterator.hasNext()) {
|
String key = iterator.next();
|
if (key.startsWith(projectId)) {
|
iterator.remove();
|
}
|
}
|
}
|
|
public static synchronized void removeClassMethodCache(String projectId) {
|
Iterator<String> iterator = classMethodCache.keySet().iterator();
|
while (iterator.hasNext()) {
|
String key = iterator.next();
|
if (key.startsWith(projectId)) {
|
iterator.remove();
|
}
|
}
|
}
|
|
public static void removeOldFile(String bakPath, String projectId) {
|
File dir = new File(bakPath);
|
if (dir.exists() && dir.isDirectory()) {
|
File[] files = dir.listFiles();
|
if (null != files && files.length > 0) {
|
for (File file : files) {
|
if (file.getName().startsWith(projectId)) {
|
file.delete();
|
}
|
}
|
}
|
}
|
}
|
|
/**
|
* @description: 项目启动加载已发布的dll和jar
|
* @author: dzd
|
* @date: 2024/10/10 11:58
|
**/
|
public static void loadProjectPublish(String bakPath) {
|
File dir = new File(bakPath);
|
if (dir.exists() && dir.isDirectory()) {
|
File[] files = dir.listFiles();
|
if (null != files && files.length > 0) {
|
for (File file : files) {
|
String fileName = file.getName();
|
if (fileName.endsWith(".jar")) {
|
String[] split = fileName.split(MdkConstant.SPLIT);
|
String projectId = split[0];
|
|
URLClassLoader urlClassLoader = null;
|
try {
|
// 加载新的jar
|
urlClassLoader = loadJar(projectId,file.getAbsolutePath());
|
} catch (Exception e) {
|
throw new RuntimeException("加载jar异常", e);
|
}
|
// 成功后加入缓存
|
addClassLoaderCache(projectId, urlClassLoader);
|
}
|
}
|
}
|
}
|
|
}
|
|
public static HashMap<String, Object> run(IAILModel model, Object[] paramsValueArray, String projectId) throws Exception {
|
if (model == null) {
|
throw new RuntimeException("模型文件不能为空!");
|
} else {
|
// 上传时校验文件名和class的开头是相同的,保证className唯一,可以用 projectId_className 作为唯一key
|
String classCacheKey = projectId + "_" + model.getClassName();
|
String methodParams = Arrays.stream(model.getParamsArray()).map(e -> e.getName()).collect(Collectors.joining(","));
|
String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")";
|
// 因为一个类下可能有多多个方法,所以这里classCacheKey 和 classMethodCacheKey 要分开判断
|
if (classCache.containsKey(classCacheKey)) {
|
if (classMethodCache.containsKey(classMethodCacheKey)) {
|
return (HashMap) classMethodCache.get(classMethodCacheKey).invoke(classCache.get(classCacheKey), paramsValueArray);
|
} else {
|
// 运行过这个类的其他方法,类从缓存中取,新建方法
|
Object o = classCache.get(classCacheKey);
|
Class<?>[] paramsArray = new Class[2];
|
paramsArray[0] = int.class;
|
paramsArray[1] = Object[].class;
|
Method method = o.getClass().getMethod(model.getMethodName(), paramsArray);
|
classMethodCache.put(classMethodCacheKey, method);
|
Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
|
HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
|
return map;
|
}
|
|
} else {
|
URLClassLoader classLoader = getClassLoader(projectId);
|
if (null == classLoader) {
|
throw new RuntimeException("matlab未发布,classLoader为null");
|
}
|
Class<?> clazz = classLoader.loadClass(model.getClassName());
|
Object o = clazz.newInstance();
|
Class<?>[] paramsArray = new Class[2];
|
paramsArray[0] = int.class;
|
paramsArray[1] = Object[].class;
|
Method method = clazz.getMethod(model.getMethodName(), paramsArray);
|
classCache.put(classCacheKey, o);
|
classMethodCache.put(classMethodCacheKey, method);
|
Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
|
HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
|
return map;
|
}
|
}
|
}
|
|
public static HashMap<String, Object> convertStructToMap(MWStructArray struct) {
|
HashMap<String, Object> map;
|
try {
|
map = new HashMap<>();
|
String[] fieldNames = struct.fieldNames();
|
int numElements = struct.numberOfElements();
|
for (int i = 1; i <= numElements; i++) {
|
for (String fieldName : fieldNames) {
|
MWArray value = struct.getField(fieldName, i);
|
Object javaValue = convertMWArrayToJavaObject(value);
|
map.put(fieldName, javaValue);
|
}
|
}
|
} finally {
|
struct.dispose();
|
}
|
return map;
|
}
|
|
private static Object convertMWArrayToJavaObject(MWArray mwArray) {
|
try {
|
if (mwArray instanceof MWNumericArray) {
|
MWNumericArray numArray = (MWNumericArray) mwArray;
|
if (numArray.numberOfElements() == 1) {
|
MWClassID mwClassID = numArray.classID();
|
if (mwClassID.equals(MWClassID.DOUBLE)) {
|
return numArray.getDouble();
|
} else if (mwClassID.equals(MWClassID.SINGLE)) {
|
return numArray.getFloat();
|
} else if (mwClassID.equals(MWClassID.INT8) || mwClassID.equals(MWClassID.UINT8)) {
|
return numArray.getByte();
|
} else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
|
return numArray.getShort();
|
} else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
|
return numArray.getInt();
|
} else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
|
return numArray.getLong();
|
} else if (mwClassID.equals(MWClassID.LOGICAL)) {
|
return numArray.getByte();
|
}
|
return null;
|
} else {
|
MWClassID mwClassID = numArray.classID();
|
if (mwClassID.equals(MWClassID.DOUBLE)) {
|
return numArray.toDoubleArray();
|
} else if (mwClassID.equals(MWClassID.SINGLE)) {
|
return numArray.toFloatArray();
|
} else if (mwClassID.equals(MWClassID.INT8)) {
|
return numArray.toByteArray();
|
} else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
|
return numArray.toShortArray();
|
} else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
|
return numArray.toIntArray();
|
} else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
|
return numArray.toLongArray();
|
} else if (mwClassID.equals(MWClassID.LOGICAL)) {
|
return numArray.toByteArray();
|
}
|
return null;
|
}
|
} else if (mwArray instanceof MWCharArray) {
|
MWCharArray stringArray = (MWCharArray) mwArray;
|
if (stringArray.numberOfElements() == 1) {
|
return stringArray.getChar(1);
|
} else {
|
// 不支持string,string都是用char[]返回,所以这里将char[]转为string
|
int[] dimensions = stringArray.getDimensions();
|
if (dimensions[0] == 1) {
|
//string
|
return stringArray.toString();
|
} else {
|
//String[] 暂时只考虑一维string
|
char[][] array = (char[][]) stringArray.toArray();
|
String[] list = new String[dimensions[0]];
|
for (int i = 0; i < dimensions[0]; i++) {
|
list[i] = (String.valueOf(array[i]));
|
}
|
return list;
|
}
|
}
|
} else if (mwArray instanceof MWLogicalArray) {
|
MWLogicalArray logicalArray = (MWLogicalArray) mwArray;
|
if (logicalArray.numberOfElements() == 1) {
|
return logicalArray.getBoolean(1);
|
} else {
|
int[] dimensions = logicalArray.getDimensions();
|
return logicalArray.toArray();
|
}
|
} else if (mwArray instanceof MWStructArray) {
|
MWStructArray structArray = (MWStructArray) mwArray;
|
return convertStructToMap(structArray);
|
}
|
} finally {
|
mwArray.dispose();
|
}
|
return null;
|
}
|
|
public static HashMap<String, Object> handleModelSettings(List<MlModelMethodSettingDTO> modelSettings) {
|
HashMap<String, Object> resultMap = null;
|
try {
|
resultMap = new HashMap<>(modelSettings.size());
|
for (MlModelMethodSettingDTO modelSetting : modelSettings) {
|
switch (modelSetting.getValueType()) {
|
case "int":
|
resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue()));
|
break;
|
case "string":
|
resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue());
|
break;
|
case "decimal":
|
resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue()));
|
break;
|
case "decimalArray":
|
JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue());
|
double[] doubles = new double[jsonArray.size()];
|
for (int i = 0; i < jsonArray.size(); i++) {
|
doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i)));
|
}
|
resultMap.put(modelSetting.getSettingKey(), doubles);
|
break;
|
}
|
}
|
} catch (NumberFormatException e) {
|
throw new RuntimeException("模型参数有误,请检查!!!");
|
}
|
return resultMap;
|
}
|
|
public static MWStructArray convertMapToStruct(Map<String, Object> map) {
|
String[] fieldNames = map.keySet().toArray(new String[0]);
|
MWStructArray struct = new MWStructArray(1, 1, fieldNames);
|
for (String key : map.keySet()) {
|
Object value = map.get(key);
|
if (value instanceof Number) {
|
struct.set(key, 1, ((Number) value).doubleValue());
|
} else if (value instanceof String) {
|
struct.set(key, 1, value);
|
}
|
// 可以根据需要添加更多类型的处理
|
}
|
return struct;
|
}
|
}
|