houzhongyi
2024-07-11 e7c1260db32209a078a962aaa0ad5492c35774fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
package com.iailab.framework.datapermission.core.db;
 
import com.iailab.framework.common.util.collection.SetUtils;
import com.iailab.framework.datapermission.core.rule.DataPermissionRule;
import com.iailab.framework.datapermission.core.rule.DataPermissionRuleFactory;
import com.iailab.framework.mybatis.core.util.MyBatisUtils;
import com.iailab.framework.test.core.ut.BaseMockitoUnitTest;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
 
import java.sql.Connection;
import java.util.*;
 
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
 
/**
 * {@link DataPermissionDatabaseInterceptor} 的单元测试
 * 主要测试 {@link DataPermissionDatabaseInterceptor#beforePrepare(StatementHandler, Connection, Integer)}
 * 和 {@link DataPermissionDatabaseInterceptor#beforeUpdate(Executor, MappedStatement, Object)}
 * 以及在这个过程中,ContextHolder 和 MappedStatementCache
 *
 * @author iailab
 */
public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
 
    @InjectMocks
    private DataPermissionDatabaseInterceptor interceptor;
 
    @Mock
    private DataPermissionRuleFactory ruleFactory;
 
    @BeforeEach
    public void setUp() {
        // 清理上下文
        DataPermissionDatabaseInterceptor.ContextHolder.clear();
        // 清空缓存
        interceptor.getMappedStatementCache().clear();
    }
 
    @Test // 不存在规则,且不匹配
    public void testBeforeQuery_withoutRule() {
        try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
            // 准备参数
            MappedStatement mappedStatement = mock(MappedStatement.class);
            BoundSql boundSql = mock(BoundSql.class);
 
            // 调用
            interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
            // 断言
            pluginUtilsMock.verify(() -> PluginUtils.mpBoundSql(boundSql), never());
        }
    }
 
    @Test // 存在规则,且不匹配
    public void testBeforeQuery_withMatchRule() {
        try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
            // 准备参数
            MappedStatement mappedStatement = mock(MappedStatement.class);
            BoundSql boundSql = mock(BoundSql.class);
            // mock 方法(数据权限)
            when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
                    .thenReturn(singletonList(new DeptDataPermissionRule()));
            // mock 方法(MPBoundSql)
            PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
            pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
            // mock 方法(SQL)
            String sql = "select * from t_user where id = 1";
            when(mpBs.sql()).thenReturn(sql);
            // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
 
            // 调用
            interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
            // 断言
            verify(mpBs, times(1)).sql(
                    eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
            // 断言缓存
            assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
        }
    }
 
    @Test // 存在规则,但不匹配
    public void testBeforeQuery_withoutMatchRule() {
        try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
            // 准备参数
            MappedStatement mappedStatement = mock(MappedStatement.class);
            BoundSql boundSql = mock(BoundSql.class);
            // mock 方法(数据权限)
            when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
                    .thenReturn(singletonList(new DeptDataPermissionRule()));
            // mock 方法(MPBoundSql)
            PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
            pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
            // mock 方法(SQL)
            String sql = "select * from t_role where id = 1";
            when(mpBs.sql()).thenReturn(sql);
            // 针对 ContextHolder 和 MappedStatementCache 暂时不 mock,主要想校验过程中,数据是否正确
 
            // 调用
            interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
            // 断言
            verify(mpBs, times(1)).sql(
                    eq("SELECT * FROM t_role WHERE id = 1"));
            // 断言缓存
            assertFalse(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
        }
    }
 
    @Test
    public void testAddNoRewritable() {
        // 准备参数
        MappedStatement ms = mock(MappedStatement.class);
        List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
        // mock 方法
        when(ms.getId()).thenReturn("selectById");
 
        // 调用
        interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
        // 断言
        Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements =
                interceptor.getMappedStatementCache().getNoRewritableMappedStatements();
        assertEquals(1, noRewritableMappedStatements.size());
        assertEquals(SetUtils.asSet("selectById"), noRewritableMappedStatements.get(DeptDataPermissionRule.class));
    }
 
    @Test
    public void testNoRewritable() {
        // 准备参数
        MappedStatement ms = mock(MappedStatement.class);
        // mock 方法
        when(ms.getId()).thenReturn("selectById");
        // mock 数据
        List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
        interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
 
        // 场景一,rules 为空
        assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, null));
        // 场景二,rules 非空,可重写
        assertFalse(interceptor.getMappedStatementCache().noRewritable(ms, singletonList(new EmptyDataPermissionRule())));
        // 场景三,rule 非空,不可重写
        assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, rules));
    }
 
    private static class DeptDataPermissionRule implements DataPermissionRule {
 
        private static final String COLUMN = "dept_id";
 
        @Override
        public Set<String> getTableNames() {
            return SetUtils.asSet("t_user");
        }
 
        @Override
        public Expression getExpression(String tableName, Alias tableAlias) {
            Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
            LongValue value = new LongValue(100L);
            return new EqualsTo(column, value);
        }
 
    }
 
    private static class EmptyDataPermissionRule implements DataPermissionRule {
 
        @Override
        public Set<String> getTableNames() {
            return Collections.emptySet();
        }
 
        @Override
        public Expression getExpression(String tableName, Alias tableAlias) {
            return null;
        }
 
    }
 
}