dengzedong
2025-01-03 c9e48bd2dff2b5766589024cf7264189b5f2a05c
提交 | 用户 | 时间
e7c126 1 package com.iailab.framework.websocket.core.session;
H 2
3 import cn.hutool.core.collection.CollUtil;
4 import com.iailab.framework.security.core.LoginUser;
5 import com.iailab.framework.tenant.core.context.TenantContextHolder;
6 import com.iailab.framework.websocket.core.util.WebSocketFrameworkUtils;
7 import org.springframework.web.socket.WebSocketSession;
8
9 import java.util.ArrayList;
10 import java.util.Collection;
11 import java.util.LinkedList;
12 import java.util.List;
13 import java.util.concurrent.ConcurrentHashMap;
14 import java.util.concurrent.ConcurrentMap;
15 import java.util.concurrent.CopyOnWriteArrayList;
16
17 /**
18  * 默认的 {@link WebSocketSessionManager} 实现类
19  *
20  * @author iailab
21  */
22 public class WebSocketSessionManagerImpl implements WebSocketSessionManager {
23
24     /**
25      * id 与 WebSocketSession 映射
26      *
27      * key:Session 编号
28      */
29     private final ConcurrentMap<String, WebSocketSession> idSessions = new ConcurrentHashMap<>();
30
31     /**
32      * user 与 WebSocketSession 映射
33      *
34      * key1:用户类型
35      * key2:用户编号
36      */
37     private final ConcurrentMap<Integer, ConcurrentMap<Long, CopyOnWriteArrayList<WebSocketSession>>> userSessions
38             = new ConcurrentHashMap<>();
39
40     @Override
41     public void addSession(WebSocketSession session) {
42         // 添加到 idSessions 中
43         idSessions.put(session.getId(), session);
44         // 添加到 userSessions 中
45         LoginUser user = WebSocketFrameworkUtils.getLoginUser(session);
46         if (user == null) {
47             return;
48         }
49         ConcurrentMap<Long, CopyOnWriteArrayList<WebSocketSession>> userSessionsMap = userSessions.get(user.getUserType());
50         if (userSessionsMap == null) {
51             userSessionsMap = new ConcurrentHashMap<>();
52             if (userSessions.putIfAbsent(user.getUserType(), userSessionsMap) != null) {
53                 userSessionsMap = userSessions.get(user.getUserType());
54             }
55         }
56         CopyOnWriteArrayList<WebSocketSession> sessions = userSessionsMap.get(user.getId());
57         if (sessions == null) {
58             sessions = new CopyOnWriteArrayList<>();
59             if (userSessionsMap.putIfAbsent(user.getId(), sessions) != null) {
60                 sessions = userSessionsMap.get(user.getId());
61             }
62         }
63         sessions.add(session);
64     }
65
66     @Override
67     public void removeSession(WebSocketSession session) {
68         // 移除从 idSessions 中
69         idSessions.remove(session.getId());
70         // 移除从 idSessions 中
71         LoginUser user = WebSocketFrameworkUtils.getLoginUser(session);
72         if (user == null) {
73             return;
74         }
75         ConcurrentMap<Long, CopyOnWriteArrayList<WebSocketSession>> userSessionsMap = userSessions.get(user.getUserType());
76         if (userSessionsMap == null) {
77             return;
78         }
79         CopyOnWriteArrayList<WebSocketSession> sessions = userSessionsMap.get(user.getId());
80         sessions.removeIf(session0 -> session0.getId().equals(session.getId()));
81         if (CollUtil.isEmpty(sessions)) {
82             userSessionsMap.remove(user.getId(), sessions);
83         }
84     }
85
86     @Override
87     public WebSocketSession getSession(String id) {
88         return idSessions.get(id);
89     }
90
91     @Override
92     public Collection<WebSocketSession> getSessionList(Integer userType) {
93         ConcurrentMap<Long, CopyOnWriteArrayList<WebSocketSession>> userSessionsMap = userSessions.get(userType);
94         if (CollUtil.isEmpty(userSessionsMap)) {
95             return new ArrayList<>();
96         }
97         LinkedList<WebSocketSession> result = new LinkedList<>(); // 避免扩容
98         Long contextTenantId = TenantContextHolder.getTenantId();
99         for (List<WebSocketSession> sessions : userSessionsMap.values()) {
100             if (CollUtil.isEmpty(sessions)) {
101                 continue;
102             }
103             // 特殊:如果租户不匹配,则直接排除
104             if (contextTenantId != null) {
105                 Long userTenantId = WebSocketFrameworkUtils.getTenantId(sessions.get(0));
106                 if (!contextTenantId.equals(userTenantId)) {
107                     continue;
108                 }
109             }
110             result.addAll(sessions);
111         }
112         return result;
113     }
114
115     @Override
116     public Collection<WebSocketSession> getSessionList(Integer userType, Long userId) {
117         ConcurrentMap<Long, CopyOnWriteArrayList<WebSocketSession>> userSessionsMap = userSessions.get(userType);
118         if (CollUtil.isEmpty(userSessionsMap)) {
119             return new ArrayList<>();
120         }
121         CopyOnWriteArrayList<WebSocketSession> sessions = userSessionsMap.get(userId);
122         return CollUtil.isNotEmpty(sessions) ? new ArrayList<>(sessions) : new ArrayList<>();
123     }
124
125 }