webSocket+jwt实现方式

背景:

原项目是通过前端定时器获取消息,存在消息滞后、空刷服务器、浪费带宽和资源的问题,在springboot项目集成websocket可以实现实时点对点消息推送。

原项目是在header添加jwt令牌实现认证,由于websocket不支持在头部添加信息(或许是我打开的方式不对?),最终只能采用在url添加令牌参数实现认证,感觉不够优雅,后续再想办法重构改进。

ps:至于放行websocket相关url,完全不要去考虑,危害巨大。

1、websocket核心依赖

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

2、config

@Configuration
public class WebSocketConfig {
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}

3、WebSocketServer

@Slf4j
@ServerEndpoint("/webSocket/{code}")
@Component
public class WebSocketServer {
    /**
     * concurrent包的线程安全Set,用来存放每个客户端对应的WebSocket对象。
     */
    private static CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>();

    /**
     * 与客户端的连接会话,需要通过它来给客户端发送数据
     */
    private Session session;

    /**
     * 接收识别码
     */
    private String code = "";

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("code") String code) {
        this.session = session;
        //如果存在就先删除一个,防止重复推送消息,实际这里实现了set,不删除问题也不大
        webSocketSet.removeIf(webSocket -> webSocket.code.equals(code));
        webSocketSet.add(this);
        this.code = code;
        log.info("建立WebSocket连接,code:" + code+",当前连接数:"+webSocketSet.size());
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        webSocketSet.remove(this);
        log.info("关闭WebSocket连接,code:" + this.code+",当前连接数:"+webSocketSet.size());
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("收到来[" + code + "]的信息:" + message);

    }

    @OnError
    public void onError(Session session, Throwable error) {
        log.error("websocket发生错误");
        error.printStackTrace();
    }

    /**
     * 实现服务器主动推送
     */
    private void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
    }


    /**
     * 群发自定义消息
     */
    public void sendAll(String message) {
        log.info("推送消息到" + code + ",推送内容:" + message);
        for (WebSocketServer item : webSocketSet) {
            try {
                item.sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 定点推送
     */
    public void sendTo(String message, @PathParam("code") String code) {
        for (WebSocketServer item : webSocketSet) {
            try {
                if (item.code.equals(code)) {
                    log.info("推送消息到[" + code + "],推送内容:" + message);
                    item.sendMessage(message);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        WebSocketServer that = (WebSocketServer) o;
        return Objects.equals(session, that.session) &&
                Objects.equals(code, that.code);
    }

    @Override
    public int hashCode() {
        return Objects.hash(session, code);
    }
}

4、令牌过滤器

@Slf4j
@Component
public class JwtTokenFilter extends OncePerRequestFilter {

    @Resource
    JwtProperties jwtProperties;

    @Resource
    TokenProvider tokenProvider;

    @Resource
    OnlineUserService onlineUserService;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // http连接时,客户端应该是在头信息中携带令牌
        String authorizationHeader = request.getHeader(jwtProperties.getHeader());
        if(StringUtils.isBlank(authorizationHeader)) {
            // websocket连接时,令牌放在url参数上,以后重构
            authorizationHeader = request.getParameter(jwtProperties.getHeader());
        }

        String token = null;
        if(!StringUtils.isEmpty(authorizationHeader) && authorizationHeader.startsWith(jwtProperties.getTokenStartWith())){
            token = authorizationHeader.replace(jwtProperties.getTokenStartWith(),"");
        }
        //验证token
        if(StringUtils.isNotBlank(token) && tokenProvider.validateToken(token)){
            //验证token是否在缓存中
            OnlineUserDto onlineUserDto = onlineUserService.getOne(jwtProperties.getOnlineKey() + token);
            if(onlineUserDto!=null){
                Authentication authentication = tokenProvider.getAuthentication(token, request);
                SecurityContextHolder.getContext().setAuthentication(authentication);
                log.debug("set Authentication to security context for '{}', uri: {}", authentication.getName(), request.getRequestURI());
            }
        }

        filterChain.doFilter(request, response);
    }
}

5、在业务中调用方式(伪代码)

    @Resource
    private WebSocketServer webSocketServer;

    // 向客户端推送实时消息
    webSocketServer.sendTo(content, sysUser.getId());

6、前端,伪代码

    getMessageCount() {
      getMyMessageCount().then(res => {
        const count = res
        this.messageCount = count > 0 ? count : null
      })
    },
    initWebSocket() {
      const wsUri = process.env.VUE_APP_WS_API + '/webSocket/' + this.user.id + '?Authorization=' + getToken()
      this.websock = new WebSocket(wsUri)
      this.websock.onopen = this.webSocketOnOpen
      this.websock.onerror = this.webSocketOnError
      this.websock.onmessage = this.webSocketOnMessage
    },
    webSocketOnOpen(e) {
      console.log('websocket 已经连接', e)
    },
    webSocketOnError(e) {
      this.$notify({
        title: 'WebSocket连接发生错误',
        type: 'error',
        duration: 0
      })
    },
    webSocketOnMessage(e) {
      const data = e.data
      this.$notify({
        title: '',
        message: data,
        type: 'success',
        dangerouslyUseHTMLString: true,
        duration: 5500
      })
      this.getMessageCount()
    }
原文地址:https://www.cnblogs.com/asker009/p/13507877.html