潘志宝
2024-09-23 0a2f6f78683ba1c4e07f1359c1e7bf105a4bd507
提交 | 用户 | 时间
b425df 1 package com.iailab.module.model.common.xss;
2
3 import org.apache.commons.io.IOUtils;
4 import org.apache.commons.lang3.StringUtils;
5 import org.springframework.http.HttpHeaders;
6 import org.springframework.http.MediaType;
7
8 import javax.servlet.ReadListener;
9 import javax.servlet.ServletInputStream;
10 import javax.servlet.http.HttpServletRequest;
11 import javax.servlet.http.HttpServletRequestWrapper;
12 import java.io.ByteArrayInputStream;
13 import java.io.IOException;
14 import java.util.LinkedHashMap;
15 import java.util.Map;
16
17 /**
18  * XSS过滤处理
19  *
20  * @author Mark sunlightcs@gmail.com
21  */
22 public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
23     //没被包装过的HttpServletRequest(特殊场景,需要自己过滤)
24     HttpServletRequest orgRequest;
25     //html过滤
26     private final static HTMLFilter htmlFilter = new HTMLFilter();
27
28     public XssHttpServletRequestWrapper(HttpServletRequest request) {
29         super(request);
30         orgRequest = request;
31     }
32
33     @Override
34     public ServletInputStream getInputStream() throws IOException {
35         //非json类型,直接返回
36         if(!MediaType.APPLICATION_JSON_VALUE.equalsIgnoreCase(super.getHeader(HttpHeaders.CONTENT_TYPE))){
37             return super.getInputStream();
38         }
39
40         //为空,直接返回
41         String json = IOUtils.toString(super.getInputStream(), "utf-8");
42         if (StringUtils.isBlank(json)) {
43             return super.getInputStream();
44         }
45
46         //xss过滤
47         json = xssEncode(json);
48         final ByteArrayInputStream bis = new ByteArrayInputStream(json.getBytes("utf-8"));
49         return new ServletInputStream() {
50             @Override
51             public boolean isFinished() {
52                 return true;
53             }
54
55             @Override
56             public boolean isReady() {
57                 return true;
58             }
59
60             @Override
61             public void setReadListener(ReadListener readListener) {
62
63             }
64
65             @Override
66             public int read() throws IOException {
67                 return bis.read();
68             }
69         };
70     }
71
72     @Override
73     public String getParameter(String name) {
74         String value = super.getParameter(xssEncode(name));
75         if (StringUtils.isNotBlank(value)) {
76             value = xssEncode(value);
77         }
78         return value;
79     }
80
81     @Override
82     public String[] getParameterValues(String name) {
83         String[] parameters = super.getParameterValues(name);
84         if (parameters == null || parameters.length == 0) {
85             return null;
86         }
87
88         for (int i = 0; i < parameters.length; i++) {
89             parameters[i] = xssEncode(parameters[i]);
90         }
91         return parameters;
92     }
93
94     @Override
95     public Map<String,String[]> getParameterMap() {
96         Map<String,String[]> map = new LinkedHashMap<>();
97         Map<String,String[]> parameters = super.getParameterMap();
98         for (String key : parameters.keySet()) {
99             String[] values = parameters.get(key);
100             for (int i = 0; i < values.length; i++) {
101                 values[i] = xssEncode(values[i]);
102             }
103             map.put(key, values);
104         }
105         return map;
106     }
107
108     @Override
109     public String getHeader(String name) {
110         String value = super.getHeader(xssEncode(name));
111         if (StringUtils.isNotBlank(value)) {
112             value = xssEncode(value);
113         }
114         return value;
115     }
116
117     private String xssEncode(String input) {
118         return htmlFilter.filter(input);
119     }
120
121     /**
122      * 获取最原始的request
123      */
124     public HttpServletRequest getOrgRequest() {
125         return orgRequest;
126     }
127
128     /**
129      * 获取最原始的request
130      */
131     public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
132         if (request instanceof XssHttpServletRequestWrapper) {
133             return ((XssHttpServletRequestWrapper) request).getOrgRequest();
134         }
135
136         return request;
137     }
138
139 }