houzhongyi
2024-07-11 e7c1260db32209a078a962aaa0ad5492c35774fb
提交 | 用户 | 时间
e7c126 1 package com.iailab.framework.datapermission.core.db;
H 2
3 import com.iailab.framework.common.util.collection.SetUtils;
4 import com.iailab.framework.datapermission.core.rule.DataPermissionRule;
5 import com.iailab.framework.datapermission.core.rule.DataPermissionRuleFactory;
6 import com.iailab.framework.mybatis.core.util.MyBatisUtils;
7 import com.iailab.framework.test.core.ut.BaseMockitoUnitTest;
8 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
9 import net.sf.jsqlparser.expression.Alias;
10 import net.sf.jsqlparser.expression.Expression;
11 import net.sf.jsqlparser.expression.LongValue;
12 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
13 import net.sf.jsqlparser.schema.Column;
14 import org.apache.ibatis.executor.Executor;
15 import org.apache.ibatis.executor.statement.StatementHandler;
16 import org.apache.ibatis.mapping.BoundSql;
17 import org.apache.ibatis.mapping.MappedStatement;
18 import org.junit.jupiter.api.BeforeEach;
19 import org.junit.jupiter.api.Test;
20 import org.mockito.InjectMocks;
21 import org.mockito.Mock;
22 import org.mockito.MockedStatic;
23
24 import java.sql.Connection;
25 import java.util.*;
26
27 import static java.util.Collections.singletonList;
28 import static org.junit.jupiter.api.Assertions.*;
29 import static org.mockito.Mockito.*;
30
31 /**
32  * {@link DataPermissionDatabaseInterceptor} 的单元测试
33  * 主要测试 {@link DataPermissionDatabaseInterceptor#beforePrepare(StatementHandler, Connection, Integer)}
34  * 和 {@link DataPermissionDatabaseInterceptor#beforeUpdate(Executor, MappedStatement, Object)}
35  * 以及在这个过程中,ContextHolder 和 MappedStatementCache
36  *
37  * @author iailab
38  */
39 public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
40
41     @InjectMocks
42     private DataPermissionDatabaseInterceptor interceptor;
43
44     @Mock
45     private DataPermissionRuleFactory ruleFactory;
46
47     @BeforeEach
48     public void setUp() {
49         // 清理上下文
50         DataPermissionDatabaseInterceptor.ContextHolder.clear();
51         // 清空缓存
52         interceptor.getMappedStatementCache().clear();
53     }
54
55     @Test // 不存在规则,且不匹配
56     public void testBeforeQuery_withoutRule() {
57         try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
58             // 准备参数
59             MappedStatement mappedStatement = mock(MappedStatement.class);
60             BoundSql boundSql = mock(BoundSql.class);
61
62             // 调用
63             interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
64             // 断言
65             pluginUtilsMock.verify(() -> PluginUtils.mpBoundSql(boundSql), never());
66         }
67     }
68
69     @Test // 存在规则,且不匹配
70     public void testBeforeQuery_withMatchRule() {
71         try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
72             // 准备参数
73             MappedStatement mappedStatement = mock(MappedStatement.class);
74             BoundSql boundSql = mock(BoundSql.class);
75             // mock 方法(数据权限)
76             when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
77                     .thenReturn(singletonList(new DeptDataPermissionRule()));
78             // mock 方法(MPBoundSql)
79             PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
80             pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
81             // mock 方法(SQL)
82             String sql = "select * from t_user where id = 1";
83             when(mpBs.sql()).thenReturn(sql);
84             // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
85
86             // 调用
87             interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
88             // 断言
89             verify(mpBs, times(1)).sql(
90                     eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
91             // 断言缓存
92             assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
93         }
94     }
95
96     @Test // 存在规则,但不匹配
97     public void testBeforeQuery_withoutMatchRule() {
98         try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
99             // 准备参数
100             MappedStatement mappedStatement = mock(MappedStatement.class);
101             BoundSql boundSql = mock(BoundSql.class);
102             // mock 方法(数据权限)
103             when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
104                     .thenReturn(singletonList(new DeptDataPermissionRule()));
105             // mock 方法(MPBoundSql)
106             PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
107             pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
108             // mock 方法(SQL)
109             String sql = "select * from t_role where id = 1";
110             when(mpBs.sql()).thenReturn(sql);
111             // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
112
113             // 调用
114             interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
115             // 断言
116             verify(mpBs, times(1)).sql(
117                     eq("SELECT * FROM t_role WHERE id = 1"));
118             // 断言缓存
119             assertFalse(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
120         }
121     }
122
123     @Test
124     public void testAddNoRewritable() {
125         // 准备参数
126         MappedStatement ms = mock(MappedStatement.class);
127         List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
128         // mock 方法
129         when(ms.getId()).thenReturn("selectById");
130
131         // 调用
132         interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
133         // 断言
134         Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements =
135                 interceptor.getMappedStatementCache().getNoRewritableMappedStatements();
136         assertEquals(1, noRewritableMappedStatements.size());
137         assertEquals(SetUtils.asSet("selectById"), noRewritableMappedStatements.get(DeptDataPermissionRule.class));
138     }
139
140     @Test
141     public void testNoRewritable() {
142         // 准备参数
143         MappedStatement ms = mock(MappedStatement.class);
144         // mock 方法
145         when(ms.getId()).thenReturn("selectById");
146         // mock 数据
147         List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
148         interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
149
150         // 场景一,rules 为空
151         assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, null));
152         // 场景二,rules 非空,可重写
153         assertFalse(interceptor.getMappedStatementCache().noRewritable(ms, singletonList(new EmptyDataPermissionRule())));
154         // 场景三,rule 非空,不可重写
155         assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, rules));
156     }
157
158     private static class DeptDataPermissionRule implements DataPermissionRule {
159
160         private static final String COLUMN = "dept_id";
161
162         @Override
163         public Set<String> getTableNames() {
164             return SetUtils.asSet("t_user");
165         }
166
167         @Override
168         public Expression getExpression(String tableName, Alias tableAlias) {
169             Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
170             LongValue value = new LongValue(100L);
171             return new EqualsTo(column, value);
172         }
173
174     }
175
176     private static class EmptyDataPermissionRule implements DataPermissionRule {
177
178         @Override
179         public Set<String> getTableNames() {
180             return Collections.emptySet();
181         }
182
183         @Override
184         public Expression getExpression(String tableName, Alias tableAlias) {
185             return null;
186         }
187
188     }
189
190 }