潘志宝
2024-08-26 368beb362d7ffb017174d7d79a16032d0647776f
提交 | 用户 | 时间
e7c126 1 package com.iailab.module.system.controller.admin.oauth2;
H 2
3 import cn.hutool.core.collection.ListUtil;
4 import cn.hutool.core.date.LocalDateTimeUtil;
5 import cn.hutool.core.map.MapUtil;
6 import com.iailab.framework.common.core.KeyValue;
7 import com.iailab.framework.common.enums.UserTypeEnum;
8 import com.iailab.framework.common.exception.ErrorCode;
9 import com.iailab.framework.common.pojo.CommonResult;
10 import com.iailab.framework.common.util.collection.SetUtils;
11 import com.iailab.framework.common.util.object.ObjectUtils;
12 import com.iailab.framework.test.core.ut.BaseMockitoUnitTest;
13 import com.iailab.module.system.controller.admin.oauth2.vo.open.OAuth2OpenAccessTokenRespVO;
14 import com.iailab.module.system.controller.admin.oauth2.vo.open.OAuth2OpenAuthorizeInfoRespVO;
15 import com.iailab.module.system.controller.admin.oauth2.vo.open.OAuth2OpenCheckTokenRespVO;
16 import com.iailab.module.system.dal.dataobject.oauth2.OAuth2AccessTokenDO;
17 import com.iailab.module.system.dal.dataobject.oauth2.OAuth2ApproveDO;
18 import com.iailab.module.system.dal.dataobject.oauth2.OAuth2ClientDO;
19 import com.iailab.module.system.enums.oauth2.OAuth2GrantTypeEnum;
20 import com.iailab.module.system.service.oauth2.OAuth2ApproveService;
21 import com.iailab.module.system.service.oauth2.OAuth2ClientService;
22 import com.iailab.module.system.service.oauth2.OAuth2GrantService;
23 import com.iailab.module.system.service.oauth2.OAuth2TokenService;
24 import org.assertj.core.util.Lists;
25 import org.junit.jupiter.api.Test;
26 import org.mockito.InjectMocks;
27 import org.mockito.Mock;
28
29 import javax.servlet.http.HttpServletRequest;
30 import java.time.LocalDateTime;
31 import java.time.temporal.ChronoUnit;
32 import java.util.ArrayList;
33 import java.util.LinkedHashMap;
34 import java.util.List;
35
36 import static com.iailab.framework.common.util.collection.SetUtils.asSet;
37 import static com.iailab.framework.test.core.util.AssertUtils.assertPojoEquals;
38 import static com.iailab.framework.test.core.util.AssertUtils.assertServiceException;
39 import static com.iailab.framework.test.core.util.RandomUtils.randomPojo;
40 import static com.iailab.framework.test.core.util.RandomUtils.randomString;
41 import static java.util.Arrays.asList;
42 import static org.hamcrest.CoreMatchers.anyOf;
43 import static org.hamcrest.CoreMatchers.is;
44 import static org.hamcrest.MatcherAssert.assertThat;
45 import static org.junit.jupiter.api.Assertions.*;
46 import static org.mockito.ArgumentMatchers.eq;
47 import static org.mockito.ArgumentMatchers.isNull;
48 import static org.mockito.Mockito.mock;
49 import static org.mockito.Mockito.when;
50
51 /**
52  * {@link OAuth2OpenController} 的单元测试
53  *
54  * @author iailab
55  */
56 public class OAuth2OpenControllerTest extends BaseMockitoUnitTest {
57
58     @InjectMocks
59     private OAuth2OpenController oauth2OpenController;
60
61     @Mock
62     private OAuth2GrantService oauth2GrantService;
63     @Mock
64     private OAuth2ClientService oauth2ClientService;
65     @Mock
66     private OAuth2ApproveService oauth2ApproveService;
67     @Mock
68     private OAuth2TokenService oauth2TokenService;
69
70     @Test
71     public void testPostAccessToken_authorizationCode() {
72         // 准备参数
73         String granType = OAuth2GrantTypeEnum.AUTHORIZATION_CODE.getGrantType();
74         String code = randomString();
75         String redirectUri = randomString();
76         String state = randomString();
77         HttpServletRequest request = mockRequest("test_client_id", "test_client_secret");
78         // mock 方法(client)
79         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId("test_client_id");
80         when(oauth2ClientService.validOAuthClientFromCache(eq("test_client_id"), eq("test_client_secret"), eq(granType), eq(new ArrayList<>()), eq(redirectUri))).thenReturn(client);
81
82         // mock 方法(访问令牌)
83         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
84                 .setExpiresTime(LocalDateTimeUtil.offset(LocalDateTime.now(), 30000L, ChronoUnit.MILLIS));
85         when(oauth2GrantService.grantAuthorizationCodeForAccessToken(eq("test_client_id"),
86                 eq(code), eq(redirectUri), eq(state))).thenReturn(accessTokenDO);
87
88         // 调用
89         CommonResult<OAuth2OpenAccessTokenRespVO> result = oauth2OpenController.postAccessToken(request, granType,
90                 code, redirectUri, state, null, null, null, null);
91         // 断言
92         assertEquals(0, result.getCode());
93         assertPojoEquals(accessTokenDO, result.getData());
94         assertTrue(ObjectUtils.equalsAny(result.getData().getExpiresIn(), 29L, 30L));  // 执行过程会过去几毫秒
95     }
96
97     @Test
98     public void testPostAccessToken_password() {
99         // 准备参数
100         String granType = OAuth2GrantTypeEnum.PASSWORD.getGrantType();
101         String username = randomString();
102         String password = randomString();
103         String scope = "write read";
104         HttpServletRequest request = mockRequest("test_client_id", "test_client_secret");
105         // mock 方法(client)
106         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId("test_client_id");
107         when(oauth2ClientService.validOAuthClientFromCache(eq("test_client_id"), eq("test_client_secret"),
108                 eq(granType), eq(Lists.newArrayList("write", "read")), isNull())).thenReturn(client);
109
110         // mock 方法(访问令牌)
111         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
112                 .setExpiresTime(LocalDateTimeUtil.offset(LocalDateTime.now(), 30000L, ChronoUnit.MILLIS));
113         when(oauth2GrantService.grantPassword(eq(username), eq(password), eq("test_client_id"),
114                 eq(Lists.newArrayList("write", "read")))).thenReturn(accessTokenDO);
115
116         // 调用
117         CommonResult<OAuth2OpenAccessTokenRespVO> result = oauth2OpenController.postAccessToken(request, granType,
118                 null, null, null, username, password, scope, null);
119         // 断言
120         assertEquals(0, result.getCode());
121         assertPojoEquals(accessTokenDO, result.getData());
122         assertTrue(ObjectUtils.equalsAny(result.getData().getExpiresIn(), 29L, 30L));  // 执行过程会过去几毫秒
123     }
124
125     @Test
126     public void testPostAccessToken_refreshToken() {
127         // 准备参数
128         String granType = OAuth2GrantTypeEnum.REFRESH_TOKEN.getGrantType();
129         String refreshToken = randomString();
130         String password = randomString();
131         HttpServletRequest request = mockRequest("test_client_id", "test_client_secret");
132         // mock 方法(client)
133         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId("test_client_id");
134         when(oauth2ClientService.validOAuthClientFromCache(eq("test_client_id"), eq("test_client_secret"),
135                 eq(granType), eq(Lists.newArrayList()), isNull())).thenReturn(client);
136
137         // mock 方法(访问令牌)
138         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
139                 .setExpiresTime(LocalDateTimeUtil.offset(LocalDateTime.now(), 30000L, ChronoUnit.MILLIS));
140         when(oauth2GrantService.grantRefreshToken(eq(refreshToken), eq("test_client_id"))).thenReturn(accessTokenDO);
141
142         // 调用
143         CommonResult<OAuth2OpenAccessTokenRespVO> result = oauth2OpenController.postAccessToken(request, granType,
144                 null, null, null, null, password, null, refreshToken);
145         // 断言
146         assertEquals(0, result.getCode());
147         assertPojoEquals(accessTokenDO, result.getData());
148         assertTrue(ObjectUtils.equalsAny(result.getData().getExpiresIn(), 29L, 30L));  // 执行过程会过去几毫秒
149     }
150
151     @Test
152     public void testPostAccessToken_implicit() {
153         // 调用,并断言
154         assertServiceException(() -> oauth2OpenController.postAccessToken(null,
155                         OAuth2GrantTypeEnum.IMPLICIT.getGrantType(), null, null, null,
156                         null, null, null, null),
157                 new ErrorCode(400, "Token 接口不支持 implicit 授权模式"));
158     }
159
160     @Test
161     public void testRevokeToken() {
162         // 准备参数
163         HttpServletRequest request = mockRequest("demo_client_id", "demo_client_secret");
164         String token = randomString();
165         // mock 方法(client)
166         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId("demo_client_id");
167         when(oauth2ClientService.validOAuthClientFromCache(eq("demo_client_id"),
168                 eq("demo_client_secret"), isNull(), isNull(), isNull())).thenReturn(client);
169         // mock 方法(移除)
170         when(oauth2GrantService.revokeToken(eq("demo_client_id"), eq(token))).thenReturn(true);
171
172         // 调用
173         CommonResult<Boolean> result = oauth2OpenController.revokeToken(request, token);
174         // 断言
175         assertEquals(0, result.getCode());
176         assertTrue(result.getData());
177     }
178
179     @Test
180     public void testCheckToken() {
181         // 准备参数
182         HttpServletRequest request = mockRequest("demo_client_id", "demo_client_secret");
183         String token = randomString();
184         // mock 方法
185         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class).setUserType(UserTypeEnum.ADMIN.getValue()).setExpiresTime(LocalDateTimeUtil.of(1653485731195L));
186         when(oauth2TokenService.checkAccessToken(eq(token))).thenReturn(accessTokenDO);
187
188         // 调用
189         CommonResult<OAuth2OpenCheckTokenRespVO> result = oauth2OpenController.checkToken(request, token);
190         // 断言
191         assertEquals(0, result.getCode());
192         assertPojoEquals(accessTokenDO, result.getData());
193         assertEquals(1653485731L, result.getData().getExp()); // 执行过程会过去几毫秒
194     }
195
196     @Test
197     public void testAuthorize() {
198         // 准备参数
199         String clientId = randomString();
200         // mock 方法(client)
201         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId("demo_client_id").setScopes(ListUtil.toList("read", "write", "all"));
202         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(client);
203         // mock 方法(approve)
204         List<OAuth2ApproveDO> approves = asList(
205                 randomPojo(OAuth2ApproveDO.class).setScope("read").setApproved(true),
206                 randomPojo(OAuth2ApproveDO.class).setScope("write").setApproved(false));
207         when(oauth2ApproveService.getApproveList(isNull(), eq(UserTypeEnum.ADMIN.getValue()), eq(clientId))).thenReturn(approves);
208
209         // 调用
210         CommonResult<OAuth2OpenAuthorizeInfoRespVO> result = oauth2OpenController.authorize(clientId);
211         // 断言
212         assertEquals(0, result.getCode());
213         assertPojoEquals(client, result.getData().getClient());
214         assertEquals(new KeyValue<>("read", true), result.getData().getScopes().get(0));
215         assertEquals(new KeyValue<>("write", false), result.getData().getScopes().get(1));
216         assertEquals(new KeyValue<>("all", false), result.getData().getScopes().get(2));
217     }
218
219     @Test
220     public void testApproveOrDeny_grantTypeError() {
221         // 调用,并断言
222         assertServiceException(() -> oauth2OpenController.approveOrDeny(randomString(), null,
223                         null, null, null, null),
224                 new ErrorCode(400, "response_type 参数值只允许 code 和 token"));
225     }
226
227     @Test // autoApprove = true,但是不通过
228     public void testApproveOrDeny_autoApproveNo() {
229         // 准备参数
230         String responseType = "code";
231         String clientId = randomString();
232         String scope = "{\"read\": true, \"write\": false}";
233         String redirectUri = randomString();
234         String state = randomString();
235         // mock 方法
236         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class);
237         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId), isNull(), eq("authorization_code"),
238                 eq(asSet("read", "write")), eq(redirectUri))).thenReturn(client);
239
240         // 调用
241         CommonResult<String> result = oauth2OpenController.approveOrDeny(responseType, clientId,
242                 scope, redirectUri, true, state);
243         // 断言
244         assertEquals(0, result.getCode());
245         assertNull(result.getData());
246     }
247
248     @Test // autoApprove = false,但是不通过
249     public void testApproveOrDeny_ApproveNo() {
250         // 准备参数
251         String responseType = "token";
252         String clientId = randomString();
253         String scope = "{\"read\": true, \"write\": false}";
254         String redirectUri = "https://www.baidu.com";
255         String state = "test";
256         // mock 方法
257         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class);
258         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId), isNull(), eq("implicit"),
259                 eq(asSet("read", "write")), eq(redirectUri))).thenReturn(client);
260
261         // 调用
262         CommonResult<String> result = oauth2OpenController.approveOrDeny(responseType, clientId,
263                 scope, redirectUri, false, state);
264         // 断言
265         assertEquals(0, result.getCode());
266         assertEquals("https://www.baidu.com#error=access_denied&error_description=User%20denied%20access&state=test", result.getData());
267     }
268
269     @Test // autoApprove = true,通过 + token
270     public void testApproveOrDeny_autoApproveWithToken() {
271         // 准备参数
272         String responseType = "token";
273         String clientId = randomString();
274         String scope = "{\"read\": true, \"write\": false}";
275         String redirectUri = "https://www.baidu.com";
276         String state = "test";
277         // mock 方法(client)
278         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId(clientId).setAdditionalInformation(null);
279         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId), isNull(), eq("implicit"),
280                 eq(asSet("read", "write")), eq(redirectUri))).thenReturn(client);
281         // mock 方法(场景一)
282         when(oauth2ApproveService.checkForPreApproval(isNull(), eq(UserTypeEnum.ADMIN.getValue()),
283                 eq(clientId), eq(SetUtils.asSet("read", "write")))).thenReturn(true);
284         // mock 方法(访问令牌)
285         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
286                 .setAccessToken("test_access_token").setExpiresTime(LocalDateTimeUtil.offset(LocalDateTime.now(), 30010L, ChronoUnit.MILLIS));
287         when(oauth2GrantService.grantImplicit(isNull(), eq(UserTypeEnum.ADMIN.getValue()),
288                 eq(clientId), eq(ListUtil.toList("read")))).thenReturn(accessTokenDO);
289
290         // 调用
291         CommonResult<String> result = oauth2OpenController.approveOrDeny(responseType, clientId,
292                 scope, redirectUri, true, state);
293         // 断言
294         assertEquals(0, result.getCode());
295         assertThat(result.getData(), anyOf( // 29 和 30 都有一定概率,主要是时间计算
296                 is("https://www.baidu.com#access_token=test_access_token&token_type=bearer&state=test&expires_in=29&scope=read"),
297                 is("https://www.baidu.com#access_token=test_access_token&token_type=bearer&state=test&expires_in=30&scope=read")
298         ));
299     }
300
301     @Test // autoApprove = false,通过 + code
302     public void testApproveOrDeny_approveWithCode() {
303         // 准备参数
304         String responseType = "code";
305         String clientId = randomString();
306         String scope = "{\"read\": true, \"write\": false}";
307         String redirectUri = "https://www.baidu.com";
308         String state = "test";
309         // mock 方法(client)
310         OAuth2ClientDO client = randomPojo(OAuth2ClientDO.class).setClientId(clientId).setAdditionalInformation(null);
311         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId), isNull(), eq("authorization_code"),
312                 eq(asSet("read", "write")), eq(redirectUri))).thenReturn(client);
313         // mock 方法(场景二)
314         when(oauth2ApproveService.updateAfterApproval(isNull(), eq(UserTypeEnum.ADMIN.getValue()), eq(clientId),
315                 eq(MapUtil.builder(new LinkedHashMap<String, Boolean>()).put("read", true).put("write", false).build())))
316                 .thenReturn(true);
317         // mock 方法(访问令牌)
318         String authorizationCode = "test_code";
319         when(oauth2GrantService.grantAuthorizationCodeForCode(isNull(), eq(UserTypeEnum.ADMIN.getValue()),
320                 eq(clientId), eq(ListUtil.toList("read")), eq(redirectUri), eq(state))).thenReturn(authorizationCode);
321
322         // 调用
323         CommonResult<String> result = oauth2OpenController.approveOrDeny(responseType, clientId,
324                 scope, redirectUri, false, state);
325         // 断言
326         assertEquals(0, result.getCode());
327         assertEquals("https://www.baidu.com?code=test_code&state=test", result.getData());
328     }
329
330     private HttpServletRequest mockRequest(String clientId, String secret) {
331         HttpServletRequest request = mock(HttpServletRequest.class);
332         when(request.getParameter(eq("client_id"))).thenReturn(clientId);
333         when(request.getParameter(eq("client_secret"))).thenReturn(secret);
334         return request;
335     }
336
337 }