package com.iailab.module.model.matlab.common.utils;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.iail.model.IAILModel;
import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException;
import com.iailab.module.model.matlab.dto.MatlabJarFileClassInfoDTO;
import com.iailab.module.model.matlab.dto.MatlabJarFileMethodInfoDTO;
import com.iailab.module.model.matlab.dto.MlModelMethodSettingDTO;
import com.iailab.module.model.mpk.common.MdkConstant;
import com.mathworks.toolbox.javabuilder.*;
import com.mathworks.toolbox.javabuilder.internal.MWFunctionSignature;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;

@Slf4j
public class MatlabUtils {

    private static HashMap<String, URLClassLoader> classLoaderCache = new HashMap<>();
    private static HashMap<String, Object> classCache = new HashMap<>();
    private static HashMap<String, Method> classMethodCache = new HashMap<>();

    /**
     * @description: 解析matlab jar文件
     **/
    public static List<MatlabJarFileClassInfoDTO> parseJarInfo(String jarFilePath, String jatName) throws IllegalityJarException {
        List<MatlabJarFileClassInfoDTO> classInfos = new ArrayList<>();
        //加载jar用于解析内部class method
        URLClassLoader urlClassLoader = loadJar(null,jarFilePath);
        try (JarFile jarFile = new JarFile(jarFilePath)) {
            // 获取 JAR 文件中所有的条目
            Enumeration<JarEntry> entries = jarFile.entries();
            while (entries.hasMoreElements()) {
                JarEntry entry = entries.nextElement();
                String entryName = entry.getName();
                // 检查条目是否为类文件
                if (entryName.endsWith(".class") && !entryName.endsWith("Remote.class") && !entryName.endsWith("MCRFactory.class")) {
                    MatlabJarFileClassInfoDTO classInfo = new MatlabJarFileClassInfoDTO();
                    // 将类文件的路径转换为类的全限定名
                    String className = entryName.replace('/', '.').substring(0, entryName.lastIndexOf(".class"));
                    // 校验包名是否和文件名一致,不一致则说明修改过jar包名,判定为非法包
                    if (!className.startsWith(jatName)) {
                        throw new IllegalityJarException();
                    }

                    classInfo.setClassName(className);
                    try {
                        // 加载类
                        Class<?> clazz = urlClassLoader.loadClass(className);
                        // 获取该类的所有属性
                        Field[] fields = clazz.getDeclaredFields();
                        List<MatlabJarFileMethodInfoDTO> methodInfos = new ArrayList<>();
                        for (Field field : fields) {
                            int modifiers = field.getModifiers();
                            if (Modifier.isPrivate(modifiers) && Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
                                try {
                                    // 绕过访问控制检查
                                    field.setAccessible(true);
                                    Object value = field.get(null); // 静态字段不需要实例,传 null
                                    if (value instanceof MWFunctionSignature) {
                                        MatlabJarFileMethodInfoDTO methodInfo = new MatlabJarFileMethodInfoDTO();
                                        methodInfo.setMethodName(value.getClass().getField("name").get(value).toString());
                                        // 减1是为了排除最后面的setting
                                        methodInfo.setDataLength((Integer) value.getClass().getField("numInputs").get(value) - 1);
                                        methodInfo.setOutLength((Integer) value.getClass().getField("numOutputs").get(value));
                                        methodInfos.add(methodInfo);
                                    }
                                } catch (Exception e) {
                                    log.error("get matlab method info exception,className:" + className, e);
                                }
                            }
                        }
                        classInfo.setMethodInfos(methodInfos);
                    } catch (Exception e) {
                        log.error("get matlab class info exception,className:" + className, e);
                    }
                    classInfos.add(classInfo);
                }
            }
        } catch (IOException e) {
            log.error("get matlab jar info exception,jarFilePath:" + jarFilePath, e);
        } finally {
            unloadJar(urlClassLoader);
        }
        return classInfos;
    }

    public static synchronized URLClassLoader loadJar(String projectId, String... jarPaths) {
        try {
            URL[] urls = new URL[jarPaths.length];
            for (int i = 0; i < jarPaths.length; i++) {
                String jarPath = jarPaths[i];
                File jarFile = new File(jarPath);
                if (!jarFile.exists()) {
                    throw new RuntimeException("jar沒有找到!" + jarPath);
                }
                urls[i] = new File(jarPath).toURI().toURL();
            }

            // 不能设置classloader的patent,让它使用双亲委派去找系统classloader中javabuilder.jar中的依赖。因为javabuilder.jar只能加载一次,加载多次会报资源占用
            URLClassLoader urlClassLoader = new URLClassLoader(urls);
            if (projectId != null) {
                addClassLoaderCache(projectId, urlClassLoader);
            }
            log.info("成功加载jar包:" + String.join(";",jarPaths));
            return urlClassLoader;
        } catch (Exception e) {
            throw new RuntimeException("加载jar异常", e);
        }
    }

    public static synchronized void unloadJar(URLClassLoader urlClassLoader) {
        try {
            urlClassLoader.close();
            log.info("成功卸载jar包。");
        } catch (Exception e) {
            throw new RuntimeException("卸载jar异常", e);
        }
    }

    public static synchronized void addClassLoaderCache(String projectId, URLClassLoader urlClassLoader) {
        classLoaderCache.put(projectId, urlClassLoader);
    }

    public static synchronized URLClassLoader getClassLoader(String projectId) {
        return classLoaderCache.get(projectId);
    }

    public static synchronized void removeClassLoaderCache(String projectId) {
        if (classLoaderCache.containsKey(projectId)) {
            URLClassLoader urlClassLoader = classLoaderCache.get(projectId);
            unloadJar(urlClassLoader);
            classLoaderCache.remove(projectId);
            removeClassCache(projectId);
            removeClassMethodCache(projectId);
        }
    }

    public static synchronized void removeClassCache(String projectId) {
        Iterator<String> iterator = classCache.keySet().iterator();
        while (iterator.hasNext()) {
            String key = iterator.next();
            if (key.startsWith(projectId)) {
                iterator.remove();
            }
        }
    }

    public static synchronized void removeClassMethodCache(String projectId) {
        Iterator<String> iterator = classMethodCache.keySet().iterator();
        while (iterator.hasNext()) {
            String key = iterator.next();
            if (key.startsWith(projectId)) {
                iterator.remove();
            }
        }
    }

    public static void removeOldFile(String bakPath, String projectId) {
        File dir = new File(bakPath);
        if (dir.exists() && dir.isDirectory()) {
            File[] files = dir.listFiles();
            if (null != files && files.length > 0) {
                for (File file : files) {
                    if (file.getName().startsWith(projectId)) {
                        file.delete();
                    }
                }
            }
        }
    }

    /**
     * @description: 项目启动加载已发布的jar
     * @author: dzd
     * @date: 2024/10/10 11:58
     **/
    public static void loadProjectPublish(String bakPath) {
        File dir = new File(bakPath);
        if (dir.exists() && dir.isDirectory()) {
            File[] files = dir.listFiles();
            if (null != files && files.length > 0) {
                HashMap<String,List<String>> projectIdJarFilePaths = new HashMap<>();
                for (File file : files) {
                    String fileName = file.getName();
                    if (fileName.endsWith(".jar")) {
                        String[] split = fileName.split(MdkConstant.SPLIT);
                        String projectId = split[0];
                        if (projectId != null) {
                            if (projectIdJarFilePaths.containsKey(projectId)) {
                                projectIdJarFilePaths.get(projectId).add(file.getAbsolutePath());
                            } else {
                                projectIdJarFilePaths.put(projectId,new ArrayList<String>(){{add(file.getAbsolutePath());}});
                            }
                        }
                    }
                }

                try {
                    if (!CollectionUtils.isEmpty(projectIdJarFilePaths)) {
                        for (Map.Entry<String, List<String>> entry : projectIdJarFilePaths.entrySet()) {
                            // 加载新的jar
                            loadJar(entry.getKey(),entry.getValue().toArray(new String[0]));
                        }
                    }
                } catch (Exception e) {
                    throw new RuntimeException("加载jar异常", e);
                }
            }
        }

    }

    public static HashMap<String, Object> run(IAILModel model, Object[] paramsValueArray, String projectId) throws Exception {
        if (model == null) {
            throw new RuntimeException("模型文件不能为空!");
        } else {
            // 上传时校验文件名和class的开头是相同的,保证className唯一,可以用 projectId_className 作为唯一key
            String classCacheKey = projectId + "_" + model.getClassName();
            String methodParams = Arrays.stream(model.getParamsArray()).map(e -> e.getName()).collect(Collectors.joining(","));
            String classMethodCacheKey = classCacheKey + "." + model.getMethodName() + "(" + methodParams + ")";
            // 因为一个类下可能有多多个方法,所以这里classCacheKey 和 classMethodCacheKey 要分开判断
            if (classCache.containsKey(classCacheKey)) {
                if (classMethodCache.containsKey(classMethodCacheKey)) {
                    return (HashMap) classMethodCache.get(classMethodCacheKey).invoke(classCache.get(classCacheKey), paramsValueArray);
                } else {
                    // 运行过这个类的其他方法,类从缓存中取,新建方法
                    Object o = classCache.get(classCacheKey);
                    Class<?>[] paramsArray = new Class[2];
                    paramsArray[0] = int.class;
                    paramsArray[1] = Object[].class;
                    Method method = o.getClass().getMethod(model.getMethodName(), paramsArray);
                    classMethodCache.put(classMethodCacheKey, method);
                    Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
                    HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
                    return map;
                }

            } else {
                URLClassLoader classLoader = getClassLoader(projectId);
                if (null == classLoader) {
                    throw new RuntimeException("matlab未发布,classLoader为null");
                }
                Class<?> clazz = classLoader.loadClass(model.getClassName());
                Object o = clazz.newInstance();
                Class<?>[] paramsArray = new Class[2];
                paramsArray[0] = int.class;
                paramsArray[1] = Object[].class;
                Method method = clazz.getMethod(model.getMethodName(), paramsArray);
                classCache.put(classCacheKey, o);
                classMethodCache.put(classMethodCacheKey, method);
                Object[] objects = (Object[]) method.invoke(o, paramsValueArray);
                HashMap<String, Object> map = convertStructToMap((MWStructArray) objects[0]);
                return map;
            }
        }
    }

    public static HashMap<String, Object> convertStructToMap(MWStructArray struct) {
        HashMap<String, Object> map;
        try {
            map = new HashMap<>();
            String[] fieldNames = struct.fieldNames();
            int numElements = struct.numberOfElements();
            for (int i = 1; i <= numElements; i++) {
                for (String fieldName : fieldNames) {
                    MWArray value = struct.getField(fieldName, i);
                    Object javaValue = convertMWArrayToJavaObject(value);
                    map.put(fieldName, javaValue);
                }
            }
        } finally {
            struct.dispose();
        }
        return map;
    }

    private static Object convertMWArrayToJavaObject(MWArray mwArray) {
        try {
            if (mwArray instanceof MWNumericArray) {
                MWNumericArray numArray = (MWNumericArray) mwArray;
                if (numArray.numberOfElements() == 1) {
                    MWClassID mwClassID = numArray.classID();
                    if (mwClassID.equals(MWClassID.DOUBLE)) {
                        return numArray.getDouble();
                    } else if (mwClassID.equals(MWClassID.SINGLE)) {
                        return numArray.getFloat();
                    } else if (mwClassID.equals(MWClassID.INT8) || mwClassID.equals(MWClassID.UINT8)) {
                        return numArray.getByte();
                    } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
                        return numArray.getShort();
                    } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
                        return numArray.getInt();
                    } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
                        return numArray.getLong();
                    } else if (mwClassID.equals(MWClassID.LOGICAL)) {
                        return numArray.getByte();
                    }
                    return null;
                } else {
                    MWClassID mwClassID = numArray.classID();
                    if (mwClassID.equals(MWClassID.DOUBLE)) {
                        return numArray.toDoubleArray();
                    } else if (mwClassID.equals(MWClassID.SINGLE)) {
                        return numArray.toFloatArray();
                    } else if (mwClassID.equals(MWClassID.INT8)) {
                        return numArray.toByteArray();
                    } else if (mwClassID.equals(MWClassID.INT16) || mwClassID.equals(MWClassID.UINT16)) {
                        return numArray.toShortArray();
                    } else if (mwClassID.equals(MWClassID.INT32) || mwClassID.equals(MWClassID.UINT32)) {
                        return numArray.toIntArray();
                    } else if (mwClassID.equals(MWClassID.INT64) || mwClassID.equals(MWClassID.UINT64)) {
                        return numArray.toLongArray();
                    } else if (mwClassID.equals(MWClassID.LOGICAL)) {
                        return numArray.toByteArray();
                    }
                    return null;
                }
            } else if (mwArray instanceof MWCharArray) {
                MWCharArray stringArray = (MWCharArray) mwArray;
                if (stringArray.numberOfElements() == 1) {
                    return stringArray.getChar(1);
                } else {
                    // 不支持string,string都是用char[]返回,所以这里将char[]转为string
                    int[] dimensions = stringArray.getDimensions();
                    if (dimensions[0] == 1) {
                        //string
                        return stringArray.toString();
                    } else {
                        //String[] 暂时只考虑一维string
                        char[][] array = (char[][]) stringArray.toArray();
                        String[] list = new String[dimensions[0]];
                        for (int i = 0; i < dimensions[0]; i++) {
                            list[i] = (String.valueOf(array[i]));
                        }
                        return list;
                    }
                }
            } else if (mwArray instanceof MWLogicalArray) {
                MWLogicalArray logicalArray = (MWLogicalArray) mwArray;
                if (logicalArray.numberOfElements() == 1) {
                    return logicalArray.getBoolean(1);
                } else {
                    int[] dimensions = logicalArray.getDimensions();
                    return logicalArray.toArray();
                }
            } else if (mwArray instanceof MWStructArray) {
                MWStructArray structArray = (MWStructArray) mwArray;
                return convertStructToMap(structArray);
            }
        } finally {
            mwArray.dispose();
        }
        return null;
    }

    public static HashMap<String, Object> handleModelSettings(List<MlModelMethodSettingDTO> modelSettings) {
        HashMap<String, Object> resultMap = null;
        try {
            resultMap = new HashMap<>(modelSettings.size());
            for (MlModelMethodSettingDTO modelSetting : modelSettings) {
                switch (modelSetting.getValueType()) {
                    case "int":
                        resultMap.put(modelSetting.getSettingKey(), Integer.valueOf(modelSetting.getSettingValue()));
                        break;
                    case "string":
                        resultMap.put(modelSetting.getSettingKey(), modelSetting.getSettingValue());
                        break;
                    case "decimal":
                        resultMap.put(modelSetting.getSettingKey(), Double.valueOf(modelSetting.getSettingValue()));
                        break;
                    case "decimalArray":
                        JSONArray jsonArray = JSON.parseArray(modelSetting.getSettingValue());
                        double[] doubles = new double[jsonArray.size()];
                        for (int i = 0; i < jsonArray.size(); i++) {
                            doubles[i] = Double.valueOf(String.valueOf(jsonArray.get(i)));
                        }
                        resultMap.put(modelSetting.getSettingKey(), doubles);
                        break;
                }
            }
        } catch (NumberFormatException e) {
            throw new RuntimeException("模型参数有误,请检查!!!");
        }
        return resultMap;
    }

    public static MWStructArray convertMapToStruct(Map<String, Object> map) {
        String[] fieldNames = map.keySet().toArray(new String[0]);
        MWStructArray struct = new MWStructArray(1, 1, fieldNames);
        for (String key : map.keySet()) {
            Object value = map.get(key);
            if (value instanceof Number) {
                struct.set(key, 1, ((Number) value).doubleValue());
            } else if (value instanceof String) {
                struct.set(key, 1, value);
            }
            // 可以根据需要添加更多类型的处理
        }
        return struct;
    }
}