NIO实现群聊

写在前面

此处使用的是单Reactor单线程模型,服务端使用一个线程处理多个客户端连接、消息接收并转发。

场景

使用nio实现一个群聊系统:

  • 服务端监控客户端动态上下线并实现消息打印及转发;

  • 客户端实现消息发送及接收其他客户端消息。

实现

服务端:

package others.nio.groupChat;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Set;

/**
 * Group chat server built on Java NIO.
 *
 * <p>A single thread drives one {@link Selector}: it accepts new clients, reads whatever a
 * client sends, prints it, and forwards it to every other connected client. Client
 * connect/disconnect events are reported on standard output.
 *
 * <p>Messages are decoded and re-encoded with the platform default charset, so the clients
 * must run with the same default charset as the server.
 *
 * @author makeDoBetter
 * @version 1.0
 * @date 2021/4/12 14:39
 * @since JDK 1.8
 */
public class GroupChatServer {
    private static final int PORT = 1234;
    /** Size of the per-read receive buffer, in bytes. */
    private static final int BUFFER_SIZE = 1024;
    /** Upper bound, in milliseconds, for one blocking {@code select} call. */
    private static final long SELECT_TIMEOUT_MILLIS = 2000;

    private final ServerSocketChannel serverSocketChannel;
    private final Selector selector;

    public static void main(String[] args) {
        GroupChatServer server = new GroupChatServer();
        server.handler();
    }

    /**
     * Opens the listening channel on {@link #PORT} in non-blocking mode and registers it
     * with a new selector for accept events.
     *
     * @throws IllegalStateException if the channel or selector cannot be set up; failing
     *     here avoids a later NullPointerException inside {@link #handler()}
     */
    private GroupChatServer() {
        ServerSocketChannel server = null;
        Selector sel = null;
        try {
            server = ServerSocketChannel.open();
            server.configureBlocking(false);
            server.socket().bind(new InetSocketAddress(PORT));
            sel = Selector.open();
            server.register(sel, SelectionKey.OP_ACCEPT);
        } catch (IOException e) {
            System.out.println("服务端异常");
            // Release whatever was opened before the failure.
            closeQuietly(sel);
            closeQuietly(server);
            throw new IllegalStateException("failed to start group chat server on port " + PORT, e);
        }
        this.serverSocketChannel = server;
        this.selector = sel;
    }

    /**
     * Runs the event loop: waits for ready keys and dispatches accept and read events.
     * The loop ends only when the selector or the accept call fails with an
     * {@link IOException}; the selector and listening channel are closed on the way out.
     */
    private void handler() {
        try {
            while (true) {
                if (selector.select(SELECT_TIMEOUT_MILLIS) == 0) {
                    continue;
                }
                Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
                while (iterator.hasNext()) {
                    SelectionKey key = iterator.next();
                    // The selector never clears the selected set itself; remove the key
                    // so it is not processed again on the next pass.
                    iterator.remove();
                    if (!key.isValid()) {
                        continue;
                    }
                    if (key.isAcceptable()) {
                        acceptClient();
                    } else if (key.isReadable()) {
                        readInfo(key);
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            closeQuietly(selector);
            closeQuietly(serverSocketChannel);
        }
    }

    /**
     * Accepts one pending connection and registers it for read events.
     *
     * @throws IOException if the accept call itself fails
     */
    private void acceptClient() throws IOException {
        SocketChannel client = serverSocketChannel.accept();
        if (client == null) {
            // A non-blocking accept returns null when no connection is actually pending.
            return;
        }
        try {
            System.out.println("客户端" + client.getRemoteAddress() + "上线");
            client.configureBlocking(false);
            client.register(selector, SelectionKey.OP_READ);
        } catch (IOException e) {
            // A failure while setting up one client must not take the whole server down.
            e.printStackTrace();
            closeQuietly(client);
        }
    }

    /**
     * Reads one chunk from the client behind {@code key}, prints it and forwards it to all
     * other clients. End-of-stream or an I/O error is treated as the client going offline.
     *
     * @param key the readable selection key of a client channel
     */
    private void readInfo(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
        try {
            int count = channel.read(buffer);
            if (count < 0) {
                // -1 means the peer closed the connection in an orderly way. Without this
                // branch the key stays readable forever and the loop spins.
                disconnect(key, channel);
                return;
            }
            if (count > 0) {
                // Decode only the bytes actually read, not the whole backing array,
                // otherwise the message is padded with NUL characters.
                String msg = new String(buffer.array(), 0, count);
                System.out.println(msg);
                sentToOther(msg, channel);
            }
        } catch (IOException e) {
            disconnect(key, channel);
        }
    }

    /**
     * Reports a client as offline, cancels its key and closes its channel. The key and the
     * channel are released even when the remote address can no longer be queried.
     *
     * @param key the client's selection key
     * @param channel the client's channel
     */
    private void disconnect(SelectionKey key, SocketChannel channel) {
        try {
            System.out.println("客户端" + channel.getRemoteAddress() + "下线");
        } catch (IOException e) {
            e.printStackTrace();
        }
        key.cancel();
        closeQuietly(channel);
    }

    /**
     * Forwards {@code msg} to every connected client except the sender.
     *
     * @param msg the message text to forward
     * @param self the sender's channel, which is skipped
     */
    private void sentToOther(String msg, SocketChannel self) {
        Set<SelectionKey> keys = selector.keys();
        // The key set also holds the listening channel's key and keys cancelled since the
        // last select, so count only valid client channels.
        int online = 0;
        for (SelectionKey key : keys) {
            if (key.isValid() && key.channel() instanceof SocketChannel) {
                online++;
            }
        }
        System.out.println("此时有" + online + "客户端在线");

        // Encode once; the bytes are identical for every recipient.
        byte[] payload = msg.getBytes();
        for (SelectionKey key : keys) {
            if (!key.isValid()) {
                continue;
            }
            SelectableChannel channel = key.channel();
            if (!(channel instanceof SocketChannel) || channel == self) {
                continue;
            }
            SocketChannel dest = (SocketChannel) channel;
            try {
                // Each recipient gets its own ByteBuffer view: a write advances the
                // buffer's position, so a shared buffer would be empty after the first
                // recipient.
                ByteBuffer buffer = ByteBuffer.wrap(payload);
                while (buffer.hasRemaining()) {
                    if (dest.write(buffer) == 0) {
                        // The recipient's send buffer is full. Give up on this recipient
                        // instead of spinning the only event-loop thread.
                        System.err.println("send buffer full, message truncated for " + dest);
                        break;
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Closes {@code resource} if it is non-null, printing rather than propagating any
     * failure.
     *
     * @param resource the resource to close, may be {@code null}
     */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            // AutoCloseable.close() is declared to throw Exception; nothing more can be
            // done here than report it.
            e.printStackTrace();
        }
    }
}

客户端:

package others.nio.groupChat;


import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Scanner;

/**
 * Group chat client built on Java NIO.
 *
 * <p>Connects to the server at {@link #HOST}:{@link #PORT}, sends each line typed on
 * standard input, and prints messages forwarded by the server.
 *
 * <p>Messages are encoded and decoded with the platform default charset, so the server
 * must run with the same default charset as the client.
 *
 * @author makeDoBetter
 * @version 1.0
 * @date 2021/4/12 16:01
 * @since JDK 1.8
 */
public class GroupChatClient {
    private static final String HOST = "127.0.0.1";
    private static final int PORT = 1234;
    /** Size of the per-read receive buffer, in bytes. */
    private static final int BUFFER_SIZE = 1024;

    private final SocketChannel socketChannel;
    private final Selector selector;

    /**
     * Connects to the server, switches the channel to non-blocking mode and registers it
     * with a new selector for read events.
     *
     * @throws IllegalStateException if the connection or the selector cannot be set up;
     *     failing here avoids a later NullPointerException on the unset channel
     */
    public GroupChatClient() {
        SocketChannel channel = null;
        Selector sel = null;
        try {
            channel = SocketChannel.open(new InetSocketAddress(HOST, PORT));
            channel.configureBlocking(false);
            sel = Selector.open();
            channel.register(sel, SelectionKey.OP_READ);
        } catch (IOException e) {
            // Release whatever was opened before the failure.
            closeQuietly(sel);
            closeQuietly(channel);
            throw new IllegalStateException("failed to connect to " + HOST + ":" + PORT, e);
        }
        this.socketChannel = channel;
        this.selector = sel;
    }

    /**
     * Sends {@code msg} to the server, prefixed with this client's local address.
     * I/O failures are printed, not propagated.
     *
     * @param msg the text to send
     */
    public void sentInfo(String msg) {
        try {
            String info = socketChannel.getLocalAddress().toString() + "说:" + msg;
            ByteBuffer buffer = ByteBuffer.wrap(info.getBytes());
            // A non-blocking write may accept only part of the buffer; keep writing until
            // the whole message has been handed to the socket.
            while (buffer.hasRemaining()) {
                socketChannel.write(buffer);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Blocks until the server sends data, then prints what was received. When the server
     * closes the connection or an I/O error occurs, the client's channel and selector are
     * closed.
     */
    public void readInfo() {
        try {
            if (selector.select() == 0) {
                return;
            }
            Iterator<SelectionKey> iterator = selector.selectedKeys().iterator();
            while (iterator.hasNext()) {
                SelectionKey key = iterator.next();
                iterator.remove();
                if (!key.isValid() || !key.isReadable()) {
                    continue;
                }
                SocketChannel channel = (SocketChannel) key.channel();
                ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
                int count = channel.read(buffer);
                if (count < 0) {
                    // -1 means the server closed the connection; without this branch the
                    // key stays readable and an empty buffer is printed in a tight loop.
                    System.out.println("与服务端的连接已断开");
                    close();
                    return;
                }
                if (count > 0) {
                    // Decode only the bytes actually read, not the whole backing array.
                    System.out.println(new String(buffer.array(), 0, count));
                }
            }
        } catch (java.nio.channels.ClosedSelectorException e) {
            // The selector was closed by another thread just before select(); there is
            // nothing left to read.
        } catch (IOException e) {
            e.printStackTrace();
            close();
        }
    }

    /** Closes the channel and the selector; a blocked {@code select} call is woken up. */
    private void close() {
        closeQuietly(socketChannel);
        closeQuietly(selector);
    }

    /**
     * Closes {@code resource} if it is non-null, printing rather than propagating any
     * failure.
     *
     * @param resource the resource to close, may be {@code null}
     */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            // AutoCloseable.close() is declared to throw Exception; report and move on.
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        GroupChatClient groupClient = new GroupChatClient();
        try {
            System.out.println(groupClient.socketChannel.getLocalAddress().toString());

            // readInfo() blocks inside select(), so no sleep is needed between calls.
            Thread reader = new Thread(() -> {
                while (groupClient.socketChannel.isOpen()) {
                    groupClient.readInfo();
                }
            });
            // Daemon, so the JVM can exit once standard input is exhausted.
            reader.setDaemon(true);
            reader.start();

            Scanner scanner = new Scanner(System.in);
            while (groupClient.socketChannel.isOpen() && scanner.hasNextLine()) {
                groupClient.sentInfo(scanner.nextLine());
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            groupClient.close();
        }
    }
}

测试

启动了服务端与三个客户端,每个客户端发送一条消息,后关闭一个客户端,效果如下:

1. 客户端展示其他客户端发送的消息

客户端1

客户端2

客户端3

服务端展示

踩坑

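如下消息转发方法中,不可以在循环外只创建一个ByteBuffer对象、在循环体内不做任何处理就直接复用:通道的write()从缓冲区取数据时会把position推进到limit,第一次写完后缓冲区已无剩余数据(remaining()为0),后续通道的write()只会写出0字节,结果只有第一个客户端能收到消息。若确实要复用同一个缓冲区,需在每次写入前调用rewind()把position复位,或为每个通道使用duplicate()得到的独立视图。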

如果此处不理解,可了解ByteBuffer各标志位及读写操作时标志位变动。

/**
 * Forwards {@code msg} to every connected client except the sender.
 *
 * @param msg the message text to forward
 * @param self the sender's channel, which is skipped
 */
private void sentToOther(String msg, SocketChannel self) {
    Set<SelectionKey> keys = selector.keys();
    // The key set also holds the listening channel's key and keys cancelled since the
    // last select, so count only valid client channels.
    int online = 0;
    for (SelectionKey key : keys) {
        if (key.isValid() && key.channel() instanceof SocketChannel) {
            online++;
        }
    }
    System.out.println("此时有" + online + "客户端在线");

    // The encoded bytes are the same for every recipient, so encode once here.
    // Do NOT share a single ByteBuffer object across recipients, though:
    // ByteBuffer buffer = ByteBuffer.wrap(msg.getBytes());
    byte[] payload = msg.getBytes();
    for (SelectionKey key : keys) {
        if (!key.isValid()) {
            continue;
        }
        SelectableChannel channel = key.channel();
        if (!(channel instanceof SocketChannel) || channel == self) {
            continue;
        }
        SocketChannel dest = (SocketChannel) channel;
        try {
            // Wrap a fresh ByteBuffer per channel. write() advances the buffer's position
            // up to its limit, so a buffer that has already been written out has nothing
            // remaining for the next channel.
            ByteBuffer buffer = ByteBuffer.wrap(payload);
            while (buffer.hasRemaining()) {
                if (dest.write(buffer) == 0) {
                    // The recipient's send buffer is full. Give up on this recipient
                    // instead of spinning the only event-loop thread.
                    System.err.println("send buffer full, message truncated for " + dest);
                    break;
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
原文地址:https://www.cnblogs.com/fjrgg/p/14652577.html