Jay
2024-11-12 3d9106399d9a2b9c8ba7d2dea621f54fd71d2ca7
提交 | 用户 | 时间
449017 1 package com.iailab.module.model.mpk.common.utils;
D 2
3 import java.util.HashMap;
4
5
6 /**
7  * @Auther: Forrest
8  * @Date: 2020/6/8 14:05
9  * @Description:
10  */
11 public class AlgsUtils {
12     private HashMap<String, Object> model = new HashMap<String, Object>();
13
14     public HashMap<String, Object> createPredictHashmap(HashMap<String, Object> models) {
15         if ((models.containsKey("iail/mdk/model"))) {
16             if (((String) ((HashMap) models.get("iail/mdk/model")).get("param1")).isEmpty()) {
17                 String aaa = "error";
18                 model.put("param1", aaa);
19             } else {
20                 String model_train = (String) ((HashMap) models.get("iail/mdk/model")).get("param1");
21                 model.put("param1", model_train);
22             }
23         } else {
24             model = models;
25         }
26         return model;
27     }
28
29     public HashMap<String, Object> createPredictHashmapplus(HashMap<String, Object> models) {
30         if (models != null && models.containsKey("models")) {
31             if (((String) ((HashMap) models.get("models")).get("paramFile")).isEmpty()) {
32                 String aaa = "error";
33                 model.put("param1", aaa);
34             } else {
35                 String model_train = (String) ((HashMap) models.get("models")).get("paramFile");
36                 model.put("paramFile", model_train);
37                 if (((HashMap) models.get("models")).containsKey("dim")) {
38                     Object dim = ((HashMap) models.get("models")).get("dim");
39                     model.put("dim", dim);
40                 }
41             }
42         } else {
43             model = models;
44         }
45         return model;
46     }
47
48     private HashMap<String, Object> eval_pre = new HashMap<String, Object>();
49
50     /**
51      * 对返回码进行转换
52      *
53      * @param models
54      * @return
55      */
56     public int reverseResultCode(HashMap<String, Object> models) {
57         if ((models.containsKey("result_code"))) {
58             return Integer.parseInt((String) models.get("result_code"));
59         }
60         return -2;
61     }
62
63     /**
64      * 对评价指标进行转换
65      *
66      * @param models
67      * @return
68      */
69     public HashMap<String, Object> reverseEval(HashMap<String, Object> models) {
70         if ((models.containsKey("eval"))) {
71             if (((HashMap) models.get("eval")).containsKey("MAE")) {
72                 double MAE = Double.parseDouble((String) ((HashMap) models.get("eval")).get("MAE"));
73                 eval_pre.put("MAE", MAE);
74             }
75             if (((HashMap) models.get("eval")).containsKey("MAPE")) {
76                 double MAPE = Double.parseDouble((String) ((HashMap) models.get("eval")).get("MAPE"));
77                 eval_pre.put("MAPE", MAPE);
78             }
79             if (((HashMap) models.get("eval")).containsKey("RMSE")) {
80                 double MAE = Double.parseDouble((String) ((HashMap) models.get("eval")).get("RMSE"));
81                 eval_pre.put("RMSE", MAE);
82             }
83         }
84
85         return eval_pre;
86     }
87
88     /**
89      * 对models里面的参数进行转换
90      */
91     private HashMap<String, Object> train_result_models = new HashMap<String, Object>();
92
93     public HashMap<String, Object> reverseModels(HashMap<String, Object> train_result) {
94         if (train_result.containsKey("models")) {
95             train_result_models = (HashMap) train_result.get("models");
96             if (((HashMap) train_result.get("models")).containsKey("dim")) {
97                 double dim = Double.parseDouble((String) ((HashMap) train_result.get("models")).get("dim"));
98                 train_result_models.put("dim", dim);
99             }
100             train_result.put("models", train_result_models);
101         }
102         return train_result;
103     }
104
105
106     /**
107      * 获取二维数组行列数
108      *
109      * @param arr
110      * @return
111      */
112     public int[] getColAndRow(double[][] arr) {
113         int row = arr.length;
114         int col = arr[0].length;
115         int[] result = new int[2];
116         result[0] = row;
117         result[1] = col;
118         return result;
119     }
120
121     /**
122      * 两个二维数组进行合并
123      *
124      * @param data
125      * @param refs
126      * @return
127      */
128     public double[][] getMathergeArr(double[][] data, double[][] refs) {
129
130         int[] dataRowAndCol = getColAndRow(data);
131         int rowData = dataRowAndCol[0];
132         int colData = dataRowAndCol[1];
133
134         int[] refsRowAndCol = getColAndRow(refs);
135         int rowrefs = refsRowAndCol[0];
136         int colrefs = refsRowAndCol[1];
137
138         double[][] newData = new double[rowData + rowrefs][colData];
139         for (int i = 0; i < rowData; i++) {
140             for (int j = 0; j < colData; j++) {
141                 newData[i][j] = data[i][j];
142             }
143         }
144
145         for (int i = 0; i < rowrefs; i++) {
146             for (int j = 0; j < colrefs; j++) {
147                 newData[i + rowData][j] = refs[i][j];
148             }
149         }
150         return newData;
151     }
152
153     /**
154      * 对训练方法进行处理,实现评价指标的转换
155      */
156     public HashMap<String, Object> trainUtil(HashMap<String, Object> train_result, HashMap<String, Object> eval, String time) {
157         if (train_result.containsKey("eval")) {
158             eval = (HashMap<String, Object>) train_result.get("eval");
159             eval.put("time", time);
160             train_result.put("eval", eval);
161         }
162         train_result.put("result_code", reverseResultCode(train_result));
163         return train_result;
164     }
165
166     /**
167      * 对预测方法进行处理
168      */
169 //    public HashMap<String,Object> predictUtil(){
170 //
171 //    }
172 }