潘志宝
2024-08-21 c39abccd937de093fc067abffac5f66b758bc97b
提交 | 用户 | 时间
e7c126 1 package com.iailab.module.system.service.oauth2;
H 2
3 import cn.hutool.core.date.LocalDateTimeUtil;
4 import com.iailab.framework.common.enums.UserTypeEnum;
5 import com.iailab.framework.common.exception.ErrorCode;
6 import com.iailab.framework.common.pojo.PageResult;
7 import com.iailab.framework.common.util.date.DateUtils;
8 import com.iailab.framework.tenant.core.context.TenantContextHolder;
9 import com.iailab.framework.test.core.ut.BaseDbAndRedisUnitTest;
10 import com.iailab.module.system.controller.admin.oauth2.vo.token.OAuth2AccessTokenPageReqVO;
11 import com.iailab.module.system.dal.dataobject.oauth2.OAuth2AccessTokenDO;
12 import com.iailab.module.system.dal.dataobject.oauth2.OAuth2ClientDO;
13 import com.iailab.module.system.dal.dataobject.oauth2.OAuth2RefreshTokenDO;
14 import com.iailab.module.system.dal.dataobject.user.AdminUserDO;
15 import com.iailab.module.system.dal.mysql.oauth2.OAuth2AccessTokenMapper;
16 import com.iailab.module.system.dal.mysql.oauth2.OAuth2RefreshTokenMapper;
17 import com.iailab.module.system.dal.redis.oauth2.OAuth2AccessTokenRedisDAO;
18 import com.iailab.module.system.service.user.AdminUserService;
19 import org.assertj.core.util.Lists;
20 import org.junit.jupiter.api.Test;
21 import org.springframework.boot.test.mock.mockito.MockBean;
22 import org.springframework.context.annotation.Import;
23
24 import javax.annotation.Resource;
25 import java.time.LocalDateTime;
26 import java.util.List;
27
28 import static com.iailab.framework.common.util.object.ObjectUtils.cloneIgnoreId;
29 import static com.iailab.framework.test.core.util.AssertUtils.assertPojoEquals;
30 import static com.iailab.framework.test.core.util.AssertUtils.assertServiceException;
31 import static com.iailab.framework.test.core.util.RandomUtils.*;
32 import static org.junit.jupiter.api.Assertions.*;
33 import static org.mockito.ArgumentMatchers.eq;
34 import static org.mockito.Mockito.when;
35
36 /**
37  * {@link OAuth2TokenServiceImpl} 的单元测试类
38  *
39  * @author iailab
40  */
41 @Import({OAuth2TokenServiceImpl.class, OAuth2AccessTokenRedisDAO.class})
42 public class OAuth2TokenServiceImplTest extends BaseDbAndRedisUnitTest {
43
44     @Resource
45     private OAuth2TokenServiceImpl oauth2TokenService;
46
47     @Resource
48     private OAuth2AccessTokenMapper oauth2AccessTokenMapper;
49     @Resource
50     private OAuth2RefreshTokenMapper oauth2RefreshTokenMapper;
51
52     @Resource
53     private OAuth2AccessTokenRedisDAO oauth2AccessTokenRedisDAO;
54
55     @MockBean
56     private OAuth2ClientService oauth2ClientService;
57     @MockBean
58     private AdminUserService adminUserService;
59
60     @Test
61     public void testCreateAccessToken() {
62         TenantContextHolder.setTenantId(0L);
63         // 准备参数
64         Long userId = randomLongId();
65         Integer userType = UserTypeEnum.ADMIN.getValue();
66         String clientId = randomString();
67         List<String> scopes = Lists.newArrayList("read", "write");
68         // mock 方法
69         OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId)
70                 .setAccessTokenValiditySeconds(30).setRefreshTokenValiditySeconds(60);
71         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
72         // mock 数据(用户)
73         AdminUserDO user = randomPojo(AdminUserDO.class);
74         when(adminUserService.getUser(userId)).thenReturn(user);
75
76         // 调用
77         OAuth2AccessTokenDO accessTokenDO = oauth2TokenService.createAccessToken(userId, userType, clientId, scopes);
78         // 断言访问令牌
79         OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken());
80         assertPojoEquals(accessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted");
81         assertEquals(userId, accessTokenDO.getUserId());
82         assertEquals(userType, accessTokenDO.getUserType());
83         assertEquals(2, accessTokenDO.getUserInfo().size());
84         assertEquals(user.getNickname(), accessTokenDO.getUserInfo().get("nickname"));
85         assertEquals(user.getDeptId().toString(), accessTokenDO.getUserInfo().get("deptId"));
86         assertEquals(clientId, accessTokenDO.getClientId());
87         assertEquals(scopes, accessTokenDO.getScopes());
88         assertFalse(DateUtils.isExpired(accessTokenDO.getExpiresTime()));
89         // 断言访问令牌的缓存
90         OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken());
91         assertPojoEquals(accessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted");
92         // 断言刷新令牌
93         OAuth2RefreshTokenDO refreshTokenDO = oauth2RefreshTokenMapper.selectList().get(0);
94         assertPojoEquals(accessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted");
95         assertFalse(DateUtils.isExpired(refreshTokenDO.getExpiresTime()));
96     }
97
98     @Test
99     public void testRefreshAccessToken_null() {
100         // 准备参数
101         String refreshToken = randomString();
102         String clientId = randomString();
103         // mock 方法
104
105         // 调用,并断言
106         assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
107                 new ErrorCode(400, "无效的刷新令牌"));
108     }
109
110     @Test
111     public void testRefreshAccessToken_clientIdError() {
112         // 准备参数
113         String refreshToken = randomString();
114         String clientId = randomString();
115         // mock 方法
116         OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId);
117         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
118         // mock 数据(访问令牌)
119         OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
120                 .setRefreshToken(refreshToken).setClientId("error");
121         oauth2RefreshTokenMapper.insert(refreshTokenDO);
122
123         // 调用,并断言
124         assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
125                 new ErrorCode(400, "刷新令牌的客户端编号不正确"));
126     }
127
128     @Test
129     public void testRefreshAccessToken_expired() {
130         // 准备参数
131         String refreshToken = randomString();
132         String clientId = randomString();
133         // mock 方法
134         OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId);
135         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
136         // mock 数据(访问令牌)
137         OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
138                 .setRefreshToken(refreshToken).setClientId(clientId)
139                 .setExpiresTime(LocalDateTime.now().minusDays(1));
140         oauth2RefreshTokenMapper.insert(refreshTokenDO);
141
142         // 调用,并断言
143         assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
144                 new ErrorCode(401, "刷新令牌已过期"));
145         assertEquals(0, oauth2RefreshTokenMapper.selectCount());
146     }
147
148     @Test
149     public void testRefreshAccessToken_success() {
150         TenantContextHolder.setTenantId(0L);
151         // 准备参数
152         String refreshToken = randomString();
153         String clientId = randomString();
154         // mock 方法
155         OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId)
156                 .setAccessTokenValiditySeconds(30);
157         when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
158         // mock 数据(访问令牌)
159         OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
160                 .setRefreshToken(refreshToken).setClientId(clientId)
161                 .setExpiresTime(LocalDateTime.now().plusDays(1))
162                 .setUserType(UserTypeEnum.ADMIN.getValue());
163         oauth2RefreshTokenMapper.insert(refreshTokenDO);
164         // mock 数据(访问令牌)
165         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class).setRefreshToken(refreshToken)
166                 .setUserType(refreshTokenDO.getUserType());
167         oauth2AccessTokenMapper.insert(accessTokenDO);
168         oauth2AccessTokenRedisDAO.set(accessTokenDO);
169         // mock 数据(用户)
170         AdminUserDO user = randomPojo(AdminUserDO.class);
171         when(adminUserService.getUser(refreshTokenDO.getUserId())).thenReturn(user);
172
173         // 调用
174         OAuth2AccessTokenDO newAccessTokenDO = oauth2TokenService.refreshAccessToken(refreshToken, clientId);
175         // 断言,老的访问令牌被删除
176         assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()));
177         assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()));
178         // 断言,新的访问令牌
179         OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(newAccessTokenDO.getAccessToken());
180         assertPojoEquals(newAccessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted");
181         assertPojoEquals(newAccessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted",
182                 "creator", "updater");
183         assertFalse(DateUtils.isExpired(newAccessTokenDO.getExpiresTime()));
184         // 断言,新的访问令牌的缓存
185         OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(newAccessTokenDO.getAccessToken());
186         assertPojoEquals(newAccessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted");
187     }
188
189     @Test
190     public void testGetAccessToken() {
191         // mock 数据(访问令牌)
192         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
193                 .setExpiresTime(LocalDateTime.now().plusDays(1));
194         oauth2AccessTokenMapper.insert(accessTokenDO);
195         // 准备参数
196         String accessToken = accessTokenDO.getAccessToken();
197
198         // 调用
199         OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken);
200         // 断言
201         assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
202                 "creator", "updater");
203         assertPojoEquals(accessTokenDO, oauth2AccessTokenRedisDAO.get(accessToken), "createTime", "updateTime", "deleted",
204                 "creator", "updater");
205     }
206
207     @Test
208     public void testCheckAccessToken_null() {
209         // 调研,并断言
210         assertServiceException(() -> oauth2TokenService.checkAccessToken(randomString()),
211                 new ErrorCode(401, "访问令牌不存在"));
212     }
213
214     @Test
215     public void testCheckAccessToken_expired() {
216         // mock 数据(访问令牌)
217         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
218                 .setExpiresTime(LocalDateTime.now().minusDays(1));
219         oauth2AccessTokenMapper.insert(accessTokenDO);
220         // 准备参数
221         String accessToken = accessTokenDO.getAccessToken();
222
223         // 调研,并断言
224         assertServiceException(() -> oauth2TokenService.checkAccessToken(accessToken),
225                 new ErrorCode(401, "访问令牌已过期"));
226     }
227
228     @Test
229     public void testCheckAccessToken_success() {
230         // mock 数据(访问令牌)
231         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
232                 .setExpiresTime(LocalDateTime.now().plusDays(1));
233         oauth2AccessTokenMapper.insert(accessTokenDO);
234         // 准备参数
235         String accessToken = accessTokenDO.getAccessToken();
236
237         // 调研,并断言
238         OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken);
239         // 断言
240         assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
241                 "creator", "updater");
242     }
243
244     @Test
245     public void testRemoveAccessToken_null() {
246         // 调用,并断言
247         assertNull(oauth2TokenService.removeAccessToken(randomString()));
248     }
249
250     @Test
251     public void testRemoveAccessToken_success() {
252         // mock 数据(访问令牌)
253         OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
254                 .setExpiresTime(LocalDateTime.now().plusDays(1));
255         oauth2AccessTokenMapper.insert(accessTokenDO);
256         // mock 数据(刷新令牌)
257         OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
258                 .setRefreshToken(accessTokenDO.getRefreshToken());
259         oauth2RefreshTokenMapper.insert(refreshTokenDO);
260         // 调用
261         OAuth2AccessTokenDO result = oauth2TokenService.removeAccessToken(accessTokenDO.getAccessToken());
262         assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
263                 "creator", "updater");
264         // 断言数据
265         assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()));
266         assertNull(oauth2RefreshTokenMapper.selectByRefreshToken(accessTokenDO.getRefreshToken()));
267         assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()));
268     }
269
270
271     @Test
272     public void testGetAccessTokenPage() {
273         // mock 数据
274         OAuth2AccessTokenDO dbAccessToken = randomPojo(OAuth2AccessTokenDO.class, o -> { // 等会查询到
275             o.setUserId(10L);
276             o.setUserType(1);
277             o.setClientId("test_client");
278             o.setExpiresTime(LocalDateTime.now().plusDays(1));
279         });
280         oauth2AccessTokenMapper.insert(dbAccessToken);
281         // 测试 userId 不匹配
282         oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserId(20L)));
283         // 测试 userType 不匹配
284         oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserType(2)));
285         // 测试 userType 不匹配
286         oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setClientId("it_client")));
287         // 测试 expireTime 不匹配
288         oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setExpiresTime(LocalDateTimeUtil.now())));
289         // 准备参数
290         OAuth2AccessTokenPageReqVO reqVO = new OAuth2AccessTokenPageReqVO();
291         reqVO.setUserId(10L);
292         reqVO.setUserType(1);
293         reqVO.setClientId("test");
294
295         // 调用
296         PageResult<OAuth2AccessTokenDO> pageResult = oauth2TokenService.getAccessTokenPage(reqVO);
297         // 断言
298         assertEquals(1, pageResult.getTotal());
299         assertEquals(1, pageResult.getList().size());
300         assertPojoEquals(dbAccessToken, pageResult.getList().get(0));
301     }
302
303 }