houzhongyi
2024-07-11 e7c1260db32209a078a962aaa0ad5492c35774fb
提交 | 用户 | 时间
e7c126 1 package com.iailab.framework.datapermission.core.db;
H 2
3 import cn.hutool.core.collection.CollUtil;
4 import com.iailab.framework.common.util.collection.SetUtils;
5 import com.iailab.framework.datapermission.core.rule.DataPermissionRule;
6 import com.iailab.framework.datapermission.core.rule.DataPermissionRuleFactory;
7 import com.iailab.framework.mybatis.core.util.MyBatisUtils;
8 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
9 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
10 import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
11 import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
12 import lombok.Getter;
13 import lombok.RequiredArgsConstructor;
14 import net.sf.jsqlparser.expression.*;
15 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
16 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
17 import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
18 import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
19 import net.sf.jsqlparser.expression.operators.relational.InExpression;
20 import net.sf.jsqlparser.schema.Table;
21 import net.sf.jsqlparser.statement.delete.Delete;
22 import net.sf.jsqlparser.statement.select.*;
23 import net.sf.jsqlparser.statement.update.Update;
24 import org.apache.ibatis.executor.Executor;
25 import org.apache.ibatis.executor.statement.StatementHandler;
26 import org.apache.ibatis.mapping.BoundSql;
27 import org.apache.ibatis.mapping.MappedStatement;
28 import org.apache.ibatis.mapping.SqlCommandType;
29 import org.apache.ibatis.session.ResultHandler;
30 import org.apache.ibatis.session.RowBounds;
31
32 import java.sql.Connection;
33 import java.util.*;
34 import java.util.concurrent.ConcurrentHashMap;
35
36 /**
37  * 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
38  * 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法
39  *
40  * 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。
41  * 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更!
42  *
43  * @author iailab
44  */
45 @RequiredArgsConstructor
46 public class DataPermissionDatabaseInterceptor extends JsqlParserSupport implements InnerInterceptor {
47
48     private final DataPermissionRuleFactory ruleFactory;
49
50     @Getter
51     private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
52
53     @Override // SELECT 场景
54     public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
55         // 获得 Mapper 对应的数据权限的规则
56         List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
57         if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
58             return;
59         }
60
61         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
62         try {
63             // 初始化上下文
64             ContextHolder.init(rules);
65             // 处理 SQL
66             mpBs.sql(parserSingle(mpBs.sql(), null));
67         } finally {
68             // 添加是否需要重写的缓存
69             addMappedStatementCache(ms);
70             // 清空上下文
71             ContextHolder.clear();
72         }
73     }
74
75     @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
76     public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
77         PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
78         MappedStatement ms = mpSh.mappedStatement();
79         SqlCommandType sct = ms.getSqlCommandType();
80         if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
81             // 获得 Mapper 对应的数据权限的规则
82             List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
83             if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
84                 return;
85             }
86
87             PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
88             try {
89                 // 初始化上下文
90                 ContextHolder.init(rules);
91                 // 处理 SQL
92                 mpBs.sql(parserMulti(mpBs.sql(), null));
93             } finally {
94                 // 添加是否需要重写的缓存
95                 addMappedStatementCache(ms);
96                 // 清空上下文
97                 ContextHolder.clear();
98             }
99         }
100     }
101
102     @Override
103     protected void processSelect(Select select, int index, String sql, Object obj) {
104         processSelectBody(select.getSelectBody());
105         List<WithItem> withItemsList = select.getWithItemsList();
106         if (!CollectionUtils.isEmpty(withItemsList)) {
107             withItemsList.forEach(this::processSelectBody);
108         }
109     }
110
111     /**
112      * update 语句处理
113      */
114     @Override
115     protected void processUpdate(Update update, int index, String sql, Object obj) {
116         final Table table = update.getTable();
117         update.setWhere(this.builderExpression(update.getWhere(), table));
118     }
119
120     /**
121      * delete 语句处理
122      */
123     @Override
124     protected void processDelete(Delete delete, int index, String sql, Object obj) {
125         delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
126     }
127
128     // ========== 和 TenantLineInnerInterceptor 一致的逻辑 ==========
129
130     protected void processSelectBody(SelectBody selectBody) {
131         if (selectBody == null) {
132             return;
133         }
134         if (selectBody instanceof PlainSelect) {
135             processPlainSelect((PlainSelect) selectBody);
136         } else if (selectBody instanceof WithItem) {
137             WithItem withItem = (WithItem) selectBody;
138             processSelectBody(withItem.getSubSelect().getSelectBody());
139         } else {
140             SetOperationList operationList = (SetOperationList) selectBody;
141             List<SelectBody> selectBodyList = operationList.getSelects();
142             if (CollectionUtils.isNotEmpty(selectBodyList)) {
143                 selectBodyList.forEach(this::processSelectBody);
144             }
145         }
146     }
147
148     /**
149      * 处理 PlainSelect
150      */
151     protected void processPlainSelect(PlainSelect plainSelect) {
152         //#3087 github
153         List<SelectItem> selectItems = plainSelect.getSelectItems();
154         if (CollectionUtils.isNotEmpty(selectItems)) {
155             selectItems.forEach(this::processSelectItem);
156         }
157
158         // 处理 where 中的子查询
159         Expression where = plainSelect.getWhere();
160         processWhereSubSelect(where);
161
162         // 处理 fromItem
163         FromItem fromItem = plainSelect.getFromItem();
164         List<Table> list = processFromItem(fromItem);
165         List<Table> mainTables = new ArrayList<>(list);
166
167         // 处理 join
168         List<Join> joins = plainSelect.getJoins();
169         if (CollectionUtils.isNotEmpty(joins)) {
170             mainTables = processJoins(mainTables, joins);
171         }
172
173         // 当有 mainTable 时,进行 where 条件追加
174         if (CollectionUtils.isNotEmpty(mainTables)) {
175             plainSelect.setWhere(builderExpression(where, mainTables));
176         }
177     }
178
179     private List<Table> processFromItem(FromItem fromItem) {
180         // 处理括号括起来的表达式
181         while (fromItem instanceof ParenthesisFromItem) {
182             fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
183         }
184
185         List<Table> mainTables = new ArrayList<>();
186         // 无 join 时的处理逻辑
187         if (fromItem instanceof Table) {
188             Table fromTable = (Table) fromItem;
189             mainTables.add(fromTable);
190         } else if (fromItem instanceof SubJoin) {
191             // SubJoin 类型则还需要添加上 where 条件
192             List<Table> tables = processSubJoin((SubJoin) fromItem);
193             mainTables.addAll(tables);
194         } else {
195             // 处理下 fromItem
196             processOtherFromItem(fromItem);
197         }
198         return mainTables;
199     }
200
201     /**
202      * 处理where条件内的子查询
203      * <p>
204      * 支持如下:
205      * 1. in
206      * 2. =
207      * 3. >
208      * 4. <
209      * 5. >=
210      * 6. <=
211      * 7. <>
212      * 8. EXISTS
213      * 9. NOT EXISTS
214      * <p>
215      * 前提条件:
216      * 1. 子查询必须放在小括号中
217      * 2. 子查询一般放在比较操作符的右边
218      *
219      * @param where where 条件
220      */
221     protected void processWhereSubSelect(Expression where) {
222         if (where == null) {
223             return;
224         }
225         if (where instanceof FromItem) {
226             processOtherFromItem((FromItem) where);
227             return;
228         }
229         if (where.toString().indexOf("SELECT") > 0) {
230             // 有子查询
231             if (where instanceof BinaryExpression) {
232                 // 比较符号 , and , or , 等等
233                 BinaryExpression expression = (BinaryExpression) where;
234                 processWhereSubSelect(expression.getLeftExpression());
235                 processWhereSubSelect(expression.getRightExpression());
236             } else if (where instanceof InExpression) {
237                 // in
238                 InExpression expression = (InExpression) where;
239                 Expression inExpression = expression.getRightExpression();
240                 if (inExpression instanceof SubSelect) {
241                     processSelectBody(((SubSelect) inExpression).getSelectBody());
242                 }
243             } else if (where instanceof ExistsExpression) {
244                 // exists
245                 ExistsExpression expression = (ExistsExpression) where;
246                 processWhereSubSelect(expression.getRightExpression());
247             } else if (where instanceof NotExpression) {
248                 // not exists
249                 NotExpression expression = (NotExpression) where;
250                 processWhereSubSelect(expression.getExpression());
251             } else if (where instanceof Parenthesis) {
252                 Parenthesis expression = (Parenthesis) where;
253                 processWhereSubSelect(expression.getExpression());
254             }
255         }
256     }
257
258     protected void processSelectItem(SelectItem selectItem) {
259         if (selectItem instanceof SelectExpressionItem) {
260             SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
261             if (selectExpressionItem.getExpression() instanceof SubSelect) {
262                 processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
263             } else if (selectExpressionItem.getExpression() instanceof Function) {
264                 processFunction((Function) selectExpressionItem.getExpression());
265             }
266         }
267     }
268
269     /**
270      * 处理函数
271      * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
272      * <p> fixed gitee pulls/141</p>
273      *
274      * @param function
275      */
276     protected void processFunction(Function function) {
277         ExpressionList parameters = function.getParameters();
278         if (parameters != null) {
279             parameters.getExpressions().forEach(expression -> {
280                 if (expression instanceof SubSelect) {
281                     processSelectBody(((SubSelect) expression).getSelectBody());
282                 } else if (expression instanceof Function) {
283                     processFunction((Function) expression);
284                 }
285             });
286         }
287     }
288
289     /**
290      * 处理子查询等
291      */
292     protected void processOtherFromItem(FromItem fromItem) {
293         // 去除括号
294         while (fromItem instanceof ParenthesisFromItem) {
295             fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
296         }
297
298         if (fromItem instanceof SubSelect) {
299             SubSelect subSelect = (SubSelect) fromItem;
300             if (subSelect.getSelectBody() != null) {
301                 processSelectBody(subSelect.getSelectBody());
302             }
303         } else if (fromItem instanceof ValuesList) {
304             logger.debug("Perform a subQuery, if you do not give us feedback");
305         } else if (fromItem instanceof LateralSubSelect) {
306             LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
307             if (lateralSubSelect.getSubSelect() != null) {
308                 SubSelect subSelect = lateralSubSelect.getSubSelect();
309                 if (subSelect.getSelectBody() != null) {
310                     processSelectBody(subSelect.getSelectBody());
311                 }
312             }
313         }
314     }
315
316     /**
317      * 处理 sub join
318      *
319      * @param subJoin subJoin
320      * @return Table subJoin 中的主表
321      */
322     private List<Table> processSubJoin(SubJoin subJoin) {
323         List<Table> mainTables = new ArrayList<>();
324         if (subJoin.getJoinList() != null) {
325             List<Table> list = processFromItem(subJoin.getLeft());
326             mainTables.addAll(list);
327             mainTables = processJoins(mainTables, subJoin.getJoinList());
328         }
329         return mainTables;
330     }
331
332     /**
333      * 处理 joins
334      *
335      * @param mainTables 可以为 null
336      * @param joins      join 集合
337      * @return List<Table> 右连接查询的 Table 列表
338      */
339     private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
340         // join 表达式中最终的主表
341         Table mainTable = null;
342         // 当前 join 的左表
343         Table leftTable = null;
344
345         if (mainTables == null) {
346             mainTables = new ArrayList<>();
347         } else if (mainTables.size() == 1) {
348             mainTable = mainTables.get(0);
349             leftTable = mainTable;
350         }
351
352         //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
353         Deque<List<Table>> onTableDeque = new LinkedList<>();
354         for (Join join : joins) {
355             // 处理 on 表达式
356             FromItem joinItem = join.getRightItem();
357
358             // 获取当前 join 的表,subJoint 可以看作是一张表
359             List<Table> joinTables = null;
360             if (joinItem instanceof Table) {
361                 joinTables = new ArrayList<>();
362                 joinTables.add((Table) joinItem);
363             } else if (joinItem instanceof SubJoin) {
364                 joinTables = processSubJoin((SubJoin) joinItem);
365             }
366
367             if (joinTables != null) {
368
369                 // 如果是隐式内连接
370                 if (join.isSimple()) {
371                     mainTables.addAll(joinTables);
372                     continue;
373                 }
374
375                 // 当前表是否忽略
376                 Table joinTable = joinTables.get(0);
377
378                 List<Table> onTables = null;
379                 // 如果不要忽略,且是右连接,则记录下当前表
380                 if (join.isRight()) {
381                     mainTable = joinTable;
382                     if (leftTable != null) {
383                         onTables = Collections.singletonList(leftTable);
384                     }
385                 } else if (join.isLeft()) {
386                     onTables = Collections.singletonList(joinTable);
387                 } else if (join.isInner()) {
388                     if (mainTable == null) {
389                         onTables = Collections.singletonList(joinTable);
390                     } else {
391                         onTables = Arrays.asList(mainTable, joinTable);
392                     }
393                     mainTable = null;
394                 }
395
396                 mainTables = new ArrayList<>();
397                 if (mainTable != null) {
398                     mainTables.add(mainTable);
399                 }
400
401                 // 获取 join 尾缀的 on 表达式列表
402                 Collection<Expression> originOnExpressions = join.getOnExpressions();
403                 // 正常 join on 表达式只有一个,立刻处理
404                 if (originOnExpressions.size() == 1 && onTables != null) {
405                     List<Expression> onExpressions = new LinkedList<>();
406                     onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
407                     join.setOnExpressions(onExpressions);
408                     leftTable = joinTable;
409                     continue;
410                 }
411                 // 表名压栈,忽略的表压入 null,以便后续不处理
412                 onTableDeque.push(onTables);
413                 // 尾缀多个 on 表达式的时候统一处理
414                 if (originOnExpressions.size() > 1) {
415                     Collection<Expression> onExpressions = new LinkedList<>();
416                     for (Expression originOnExpression : originOnExpressions) {
417                         List<Table> currentTableList = onTableDeque.poll();
418                         if (CollectionUtils.isEmpty(currentTableList)) {
419                             onExpressions.add(originOnExpression);
420                         } else {
421                             onExpressions.add(builderExpression(originOnExpression, currentTableList));
422                         }
423                     }
424                     join.setOnExpressions(onExpressions);
425                 }
426                 leftTable = joinTable;
427             } else {
428                 processOtherFromItem(joinItem);
429                 leftTable = null;
430             }
431         }
432
433         return mainTables;
434     }
435
436     // ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
437
438     /**
439      * 处理条件
440      *
441      * @param currentExpression 当前 where 条件
442      * @param table             单个表
443      */
444     protected Expression builderExpression(Expression currentExpression, Table table) {
445         return this.builderExpression(currentExpression, Collections.singletonList(table));
446     }
447
448     /**
449      * 处理条件
450      *
451      * @param currentExpression 当前 where 条件
452      * @param tables 多个表
453      */
454     protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
455         // 没有表需要处理直接返回
456         if (CollectionUtils.isEmpty(tables)) {
457             return currentExpression;
458         }
459
460         // 第一步,获得 Table 对应的数据权限条件
461         Expression dataPermissionExpression = null;
462         for (Table table : tables) {
463             // 构建每个表的权限 Expression 条件
464             Expression expression = buildDataPermissionExpression(table);
465             if (expression == null) {
466                 continue;
467             }
468             // 合并到 dataPermissionExpression 中
469             dataPermissionExpression = dataPermissionExpression == null ? expression
470                     : new AndExpression(dataPermissionExpression, expression);
471         }
472
473         // 第二步,合并多个 Expression 条件
474         if (dataPermissionExpression == null) {
475             return currentExpression;
476         }
477         if (currentExpression == null) {
478             return dataPermissionExpression;
479         }
480         // ① 如果表达式为 Or,则需要 (currentExpression) AND dataPermissionExpression
481         if (currentExpression instanceof OrExpression) {
482             return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
483         }
484         // ② 如果表达式为 And,则直接返回 where AND dataPermissionExpression
485         return new AndExpression(currentExpression, dataPermissionExpression);
486     }
487
488     /**
489      * 构建指定表的数据权限的 Expression 过滤条件
490      *
491      * @param table 表
492      * @return Expression 过滤条件
493      */
494     private Expression buildDataPermissionExpression(Table table) {
495         // 生成条件
496         Expression allExpression = null;
497         for (DataPermissionRule rule : ContextHolder.getRules()) {
498             // 判断表名是否匹配
499             String tableName = MyBatisUtils.getTableName(table);
500             if (!rule.getTableNames().contains(tableName)) {
501                 continue;
502             }
503             // 如果有匹配的规则,说明可重写。
504             // 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
505             // 这样导致第一次无 value,被标记成无需重写;但是第二次有 value,此时会需要重写。
506             ContextHolder.setRewrite(true);
507
508             // 单条规则的条件
509             Expression oneExpress = rule.getExpression(tableName, table.getAlias());
510             if (oneExpress == null){
511                 continue;
512             }
513             // 拼接到 allExpression 中
514             allExpression = allExpression == null ? oneExpress
515                     : new AndExpression(allExpression, oneExpress);
516         }
517
518         return allExpression;
519     }
520
521     /**
522      * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
523      *
524      * @param ms MappedStatement
525      */
526     private void addMappedStatementCache(MappedStatement ms) {
527         if (ContextHolder.getRewrite()) {
528             return;
529         }
530         // 无重写,进行添加
531         mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
532     }
533
534     /**
535      * SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
536      *
537      * @author iailab
538      */
539     static final class ContextHolder {
540
541         /**
542          * 该 {@link MappedStatement} 对应的规则
543          */
544         private static final ThreadLocal<List<DataPermissionRule>> RULES = ThreadLocal.withInitial(Collections::emptyList);
545         /**
546          * SQL 是否进行重写
547          */
548         private static final ThreadLocal<Boolean> REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);
549
550         public static void init(List<DataPermissionRule> rules) {
551             RULES.set(rules);
552             REWRITE.set(false);
553         }
554
555         public static void clear() {
556             RULES.remove();
557             REWRITE.remove();
558         }
559
560         public static boolean getRewrite() {
561             return REWRITE.get();
562         }
563
564         public static void setRewrite(boolean rewrite) {
565             REWRITE.set(rewrite);
566         }
567
568         public static List<DataPermissionRule> getRules() {
569             return RULES.get();
570         }
571
572     }
573
574     /**
575      * {@link MappedStatement} 缓存
576      * 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
577      * 如果无效,则可以避免 SQL 的解析,加快速度
578      *
579      * @author iailab
580      */
581     static final class MappedStatementCache {
582
583         /**
584          * 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
585          *
586          * value:{@link MappedStatement#getId()} 编号
587          */
588         @Getter
589         private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
590
591         /**
592          * 判断是否无需重写
593          * ps:虽然有点中文式英语,但是容易读懂即可
594          *
595          * @param ms MappedStatement
596          * @param rules 数据权限规则数组
597          * @return 是否无需重写
598          */
599         public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
600             // 如果规则为空,说明无需重写
601             if (CollUtil.isEmpty(rules)) {
602                 return true;
603             }
604             // 任一规则不在 noRewritableMap 中,则说明可能需要重写
605             for (DataPermissionRule rule : rules) {
606                 Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
607                 if (!CollUtil.contains(mappedStatementIds, ms.getId())) {
608                     return false;
609                 }
610             }
611             return true;
612         }
613
614         /**
615          * 添加无需重写的 MappedStatement
616          *
617          * @param ms MappedStatement
618          * @param rules 数据权限规则数组
619          */
620         public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
621             for (DataPermissionRule rule : rules) {
622                 Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
623                 if (CollUtil.isNotEmpty(mappedStatementIds)) {
624                     mappedStatementIds.add(ms.getId());
625                 } else {
626                     noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
627                 }
628             }
629         }
630
631         /**
632          * 清空缓存
633          * 目前主要提供给单元测试
634          */
635         public void clear() {
636             noRewritableMappedStatements.clear();
637         }
638
639     }
640
641 }