Netty Study Notes (Extra Chapter)

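This is a supplementary chapter to the notes on ChannelHandler and ChannelPipeline. It reads the source code to see how ChannelHandlerContext, ChannelHandler and ChannelPipeline are linked to one another and how they work together at run time.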

1. Adding a ChannelHandler

The demo in the previous article shows that both the Server and the Client add their ChannelHandlers through ChannelPipeline's addLast method during initialization:

// Server.java

// Partial code snippet
ServerBootstrap serverBootstrap = new ServerBootstrap();
NioEventLoopGroup group = new NioEventLoopGroup();
serverBootstrap.group(group)
        .channel(NioServerSocketChannel.class)
        .localAddress(new InetSocketAddress("localhost", 9999))
        .childHandler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel socketChannel) throws Exception {
                // Add the ChannelHandlers
                socketChannel.pipeline().addLast(new OneChannelOutBoundHandler());

                socketChannel.pipeline().addLast(new OneChannelInBoundHandler());
                socketChannel.pipeline().addLast(new TwoChannelInBoundHandler());
            }
        });

In the snippet above, socketChannel.pipeline() returns an instance of DefaultChannelPipeline, which implements the ChannelPipeline interface.

DefaultChannelPipeline implements addLast as follows:

// DefaultChannelPipeline.java

@Override
public final ChannelPipeline addLast(ChannelHandler... handlers) {
    return addLast(null, handlers);
}

@Override
public final ChannelPipeline addLast(EventExecutorGroup executor, ChannelHandler... handlers) {
    ObjectUtil.checkNotNull(handlers, "handlers");

    for (ChannelHandler h: handlers) {
        if (h == null) {
            break;
        }
        addLast(executor, null, h);
    }

    return this;
}

Through these overloads, the call ends up in the following addLast method:

// DefaultChannelPipeline.java

@Override
public final ChannelPipeline addLast(EventExecutorGroup group, String name, ChannelHandler handler) {
    final AbstractChannelHandlerContext newCtx;
    synchronized (this) {
        checkMultiplicity(handler);

        newCtx = newContext(group, filterName(name, handler), handler);

        addLast0(newCtx);

        // If the registered is false it means that the channel was not registered on an eventLoop yet.
        // In this case we add the context to the pipeline and add a task that will call
        // ChannelHandler.handlerAdded(...) once the channel is registered.
        if (!registered) {
            newCtx.setAddPending();
            callHandlerCallbackLater(newCtx, true);
            return this;
        }

        EventExecutor executor = newCtx.executor();
        if (!executor.inEventLoop()) {
            callHandlerAddedInEventLoop(newCtx, executor);
            return this;
        }
    }
    callHandlerAdded0(newCtx);
    return this;
}
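The registered branch can be illustrated with a small stdlib-only sketch. MiniPipeline, addLast(String) and register() are hypothetical names and not part of the Netty API. The sketch ignores executors and synchronization on purpose. It only mirrors one idea: a handler added before registration has its handlerAdded callback queued, as callHandlerCallbackLater does, while a handler added afterwards gets the callback right away, as callHandlerAdded0 does.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of the deferred handlerAdded logic in DefaultChannelPipeline.addLast.
class MiniPipeline {
    private final List<String> handlers = new ArrayList<>();        // handler names in insertion order
    private final List<Runnable> pendingCallbacks = new ArrayList<>(); // queued handlerAdded calls
    final List<String> log = new ArrayList<>();                      // records when handlerAdded ran
    private boolean registered;

    MiniPipeline addLast(String name) {
        handlers.add(name); // the context is linked into the list immediately, like addLast0
        if (!registered) {
            // Not registered on an event loop yet: defer handlerAdded (cf. callHandlerCallbackLater)
            pendingCallbacks.add(() -> log.add("handlerAdded:" + name));
            return this;
        }
        // Already registered: invoke handlerAdded right away (cf. callHandlerAdded0)
        log.add("handlerAdded:" + name);
        return this;
    }

    void register() {
        registered = true;
        log.add("registered");
        // Run the deferred callbacks in the order the handlers were added
        for (Runnable task : pendingCallbacks) {
            task.run();
        }
        pendingCallbacks.clear();
    }

    List<String> handlers() {
        return handlers;
    }
}

class MiniPipelineDemo {
    static List<String> run() {
        MiniPipeline p = new MiniPipeline();
        p.addLast("out1").addLast("in1"); // added before registration: callbacks deferred
        p.register();                      // deferred callbacks run here
        p.addLast("in2");                  // added after registration: callback runs immediately
        return p.log;
    }

    public static void main(String[] args) {
        System.out.println(run());
        // [registered, handlerAdded:out1, handlerAdded:in1, handlerAdded:in2]
    }
}
```

In both cases the handler is linked into the list at once. Only the handlerAdded notification is postponed.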

This method passes the incoming ChannelHandler to newContext, which creates an AbstractChannelHandlerContext object. newContext is implemented as follows:

// DefaultChannelPipeline.java

private AbstractChannelHandlerContext newContext(EventExecutorGroup group, String name, ChannelHandler handler) {
    return new DefaultChannelHandlerContext(this, childExecutor(group), name, handler);
}

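It creates and returns a DefaultChannelHandlerContext. The constructor arguments show that this is where the ChannelHandlerContext, the ChannelPipeline (this) and the ChannelHandler become linked to one another.
Finally, here is the implementation of addLast0: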

// DefaultChannelPipeline.java

private void addLast0(AbstractChannelHandlerContext newCtx) {
    AbstractChannelHandlerContext prev = tail.prev;
    newCtx.prev = prev;
    newCtx.next = tail;
    prev.next = newCtx;
    tail.prev = newCtx;
}

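This code uses two fields of AbstractChannelHandlerContext, prev and next, together with DefaultChannelPipeline's tail field. The logic looks like insertion into a doubly linked list. The following snippet shows tail and the related head field: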

// DefaultChannelPipeline.java

public class DefaultChannelPipeline implements ChannelPipeline {
    final AbstractChannelHandlerContext head;
    final AbstractChannelHandlerContext tail;

    // ......
    protected DefaultChannelPipeline(Channel channel) {
        // ......

        tail = new TailContext(this);
        head = new HeadContext(this);

        head.next = tail;
        tail.prev = head;
    }

    // ......
}

// DefaultChannelPipeline.java (HeadContext is an inner class)
final class HeadContext extends AbstractChannelHandlerContext implements ChannelOutboundHandler, ChannelInboundHandler {
    // ......

    @Override
    public ChannelHandler handler() {
        return this;
    }

    //......
}

// DefaultChannelPipeline.java (TailContext is an inner class)
final class TailContext extends AbstractChannelHandlerContext implements ChannelInboundHandler {
    // ......

    @Override
    public ChannelHandler handler() {
        return this;
    }

    // ......
}

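DefaultChannelPipeline keeps two AbstractChannelHandlerContext fields, head and tail. Their classes, HeadContext and TailContext, also implement sub-interfaces of ChannelHandler. The constructor links the two into a doubly linked list. Together with addLast0 above, this shows that adding a ChannelHandler is really an insertion into the doubly linked list kept inside the ChannelPipeline. addLast0 always inserts the new context just before tail.
The ChannelHandlerContext-related classes are structured as follows:
[Figure: ChannelHandlerContext class structure]

So after ChannelHandlers have been added to a ChannelPipeline, its internal structure looks roughly like this:
[Figure: internal structure of the ChannelPipeline]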

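All ChannelHandlerContexts form a doubly linked list, with HeadContext at the head and TailContext at the tail. Because these two classes implement ChannelHandler themselves, their handler() method returns the context itself. Each time a ChannelHandler is added, a new DefaultChannelHandlerContext is created to wrap it. That context is inserted into the list at the position the add method determines; addLast puts it right before tail.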
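The pointer updates can be checked with a small model. LinkedCtx and CtxList are hypothetical stand-ins that keep only a name and the prev/next links. The constructor and addLast0 copy the wiring from the DefaultChannelPipeline source above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for AbstractChannelHandlerContext: just a name and the prev/next links.
class LinkedCtx {
    final String name;
    LinkedCtx prev;
    LinkedCtx next;

    LinkedCtx(String name) {
        this.name = name;
    }
}

// Hypothetical stand-in for the list bookkeeping in DefaultChannelPipeline.
class CtxList {
    final LinkedCtx head = new LinkedCtx("head"); // plays the role of HeadContext
    final LinkedCtx tail = new LinkedCtx("tail"); // plays the role of TailContext

    CtxList() {
        // Same wiring as the DefaultChannelPipeline constructor
        head.next = tail;
        tail.prev = head;
    }

    // Same four pointer updates as addLast0: the new context goes right before tail
    void addLast0(LinkedCtx newCtx) {
        LinkedCtx prev = tail.prev;
        newCtx.prev = prev;
        newCtx.next = tail;
        prev.next = newCtx;
        tail.prev = newCtx;
    }

    // head -> tail along next pointers (the direction inbound events travel)
    List<String> forward() {
        List<String> names = new ArrayList<>();
        for (LinkedCtx c = head; c != null; c = c.next) {
            names.add(c.name);
        }
        return names;
    }

    // tail -> head along prev pointers (the direction outbound events travel)
    List<String> backward() {
        List<String> names = new ArrayList<>();
        for (LinkedCtx c = tail; c != null; c = c.prev) {
            names.add(c.name);
        }
        return names;
    }

    public static void main(String[] args) {
        CtxList pipeline = new CtxList();
        pipeline.addLast0(new LinkedCtx("OneChannelOutBoundHandler"));
        pipeline.addLast0(new LinkedCtx("OneChannelInBoundHandler"));
        pipeline.addLast0(new LinkedCtx("TwoChannelInBoundHandler"));
        System.out.println(pipeline.forward());
        // [head, OneChannelOutBoundHandler, OneChannelInBoundHandler, TwoChannelInBoundHandler, tail]
        System.out.println(pipeline.backward());
    }
}
```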
AbstractChannelHandlerContext has a field named executionMask, which is assigned in the constructor:

// AbstractChannelHandlerContext.java

// Some code omitted

AbstractChannelHandlerContext(DefaultChannelPipeline pipeline, EventExecutor executor,
                                String name, Class<? extends ChannelHandler> handlerClass) {
    this.name = ObjectUtil.checkNotNull(name, "name");
    this.pipeline = pipeline;
    this.executor = executor;
    this.executionMask = mask(handlerClass);
    // Its ordered if its driven by the EventLoop or the given Executor is an instanceof OrderedEventExecutor.
    ordered = executor == null || executor instanceof OrderedEventExecutor;
}

// Some code omitted

mask is a static method of the ChannelHandlerMask class:

// ChannelHandlerMask.java

// Some code omitted

/**
* Return the {@code executionMask}.
*/
static int mask(Class<? extends ChannelHandler> clazz) {
    // Try to obtain the mask from the cache first. If this fails calculate it and put it in the cache for fast
    // lookup in the future.
    Map<Class<? extends ChannelHandler>, Integer> cache = MASKS.get();
    Integer mask = cache.get(clazz);
    if (mask == null) {
        mask = mask0(clazz);
        cache.put(clazz, mask);
    }
    return mask;
}

/**
* Calculate the {@code executionMask}.
*/
private static int mask0(Class<? extends ChannelHandler> handlerType) {
    int mask = MASK_EXCEPTION_CAUGHT;
    try {
        if (ChannelInboundHandler.class.isAssignableFrom(handlerType)) {
            mask |= MASK_ALL_INBOUND;

            if (isSkippable(handlerType, "channelRegistered", ChannelHandlerContext.class)) {
                mask &= ~MASK_CHANNEL_REGISTERED;
            }
            if (isSkippable(handlerType, "channelUnregistered", ChannelHandlerContext.class)) {
                mask &= ~MASK_CHANNEL_UNREGISTERED;
            }
            if (isSkippable(handlerType, "channelActive", ChannelHandlerContext.class)) {
                mask &= ~MASK_CHANNEL_ACTIVE;
            }
            if (isSkippable(handlerType, "channelInactive", ChannelHandlerContext.class)) {
                mask &= ~MASK_CHANNEL_INACTIVE;
            }
            if (isSkippable(handlerType, "channelRead", ChannelHandlerContext.class, Object.class)) {
                mask &= ~MASK_CHANNEL_READ;
            }
            if (isSkippable(handlerType, "channelReadComplete", ChannelHandlerContext.class)) {
                mask &= ~MASK_CHANNEL_READ_COMPLETE;
            }
            if (isSkippable(handlerType, "channelWritabilityChanged", ChannelHandlerContext.class)) {
                mask &= ~MASK_CHANNEL_WRITABILITY_CHANGED;
            }
            if (isSkippable(handlerType, "userEventTriggered", ChannelHandlerContext.class, Object.class)) {
                mask &= ~MASK_USER_EVENT_TRIGGERED;
            }
        }

        if (ChannelOutboundHandler.class.isAssignableFrom(handlerType)) {
            mask |= MASK_ALL_OUTBOUND;

            if (isSkippable(handlerType, "bind", ChannelHandlerContext.class,
                    SocketAddress.class, ChannelPromise.class)) {
                mask &= ~MASK_BIND;
            }
            if (isSkippable(handlerType, "connect", ChannelHandlerContext.class, SocketAddress.class,
                    SocketAddress.class, ChannelPromise.class)) {
                mask &= ~MASK_CONNECT;
            }
            if (isSkippable(handlerType, "disconnect", ChannelHandlerContext.class, ChannelPromise.class)) {
                mask &= ~MASK_DISCONNECT;
            }
            if (isSkippable(handlerType, "close", ChannelHandlerContext.class, ChannelPromise.class)) {
                mask &= ~MASK_CLOSE;
            }
            if (isSkippable(handlerType, "deregister", ChannelHandlerContext.class, ChannelPromise.class)) {
                mask &= ~MASK_DEREGISTER;
            }
            if (isSkippable(handlerType, "read", ChannelHandlerContext.class)) {
                mask &= ~MASK_READ;
            }
            if (isSkippable(handlerType, "write", ChannelHandlerContext.class,
                    Object.class, ChannelPromise.class)) {
                mask &= ~MASK_WRITE;
            }
            if (isSkippable(handlerType, "flush", ChannelHandlerContext.class)) {
                mask &= ~MASK_FLUSH;
            }
        }

        if (isSkippable(handlerType, "exceptionCaught", ChannelHandlerContext.class, Throwable.class)) {
            mask &= ~MASK_EXCEPTION_CAUGHT;
        }
    } catch (Exception e) {
        // Should never reach here.
        PlatformDependent.throwException(e);
    }

    return mask;
}

@SuppressWarnings("rawtypes")
private static boolean isSkippable(
        final Class<?> handlerType, final String methodName, final Class<?>... paramTypes) throws Exception {
    return AccessController.doPrivileged(new PrivilegedExceptionAction<Boolean>() {
        @Override
        public Boolean run() throws Exception {
            Method m;
            try {
                m = handlerType.getMethod(methodName, paramTypes);
            } catch (NoSuchMethodException e) {
                if (logger.isDebugEnabled()) {
                    logger.debug(
                        "Class {} missing method {}, assume we can not skip execution", handlerType, methodName, e);
                }
                return false;
            }
            return m != null && m.isAnnotationPresent(Skip.class);
        }
    });
}

// Some code omitted

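The logic works as follows. When a ChannelHandlerContext is created, it is bound to a ChannelHandler, and the handler's class is inspected to find out which callback methods it actually handles. The result is combined with bitwise operations and stored in the context's executionMask field; mask also caches it per handler class. Note the m.isAnnotationPresent(Skip.class) check. In the common base classes ChannelInboundHandlerAdapter and ChannelOutboundHandlerAdapter, every callback method is annotated with @Skip. When a subclass overrides one of these methods, getMethod returns the overriding method, and that method does not carry @Skip because annotations on methods are not inherited. The handler is therefore treated as supporting that callback.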
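The @Skip rule can be reproduced with plain reflection. This sketch is a stand-in, not Netty code, and rests on several assumptions. SkipMarker plays the role of Netty's @Skip, and InboundAdapterSketch plays the role of ChannelInboundHandlerAdapter. Object replaces ChannelHandlerContext so that the example compiles without Netty on the classpath.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical marker annotation standing in for Netty's @Skip.
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@interface SkipMarker {
}

// Hypothetical adapter: every callback is marked skippable, like ChannelInboundHandlerAdapter.
class InboundAdapterSketch {
    @SkipMarker
    public void channelRead(Object ctx, Object msg) {
    }

    @SkipMarker
    public void channelActive(Object ctx) {
    }
}

// User handler that overrides only channelRead.
class MyReadHandler extends InboundAdapterSketch {
    @Override
    public void channelRead(Object ctx, Object msg) {
        // real work would go here
    }
}

class SkipCheck {
    // Same idea as isSkippable: getMethod returns the most-derived public declaration,
    // and method annotations are not inherited by overriding methods.
    static boolean isSkippable(Class<?> type, String name, Class<?>... params) {
        try {
            Method m = type.getMethod(name, params);
            return m.isAnnotationPresent(SkipMarker.class);
        } catch (NoSuchMethodException e) {
            return false; // missing method: assume execution cannot be skipped
        }
    }

    public static void main(String[] args) {
        // Overridden in MyReadHandler: no marker on the overriding method, so not skippable
        System.out.println(isSkippable(MyReadHandler.class, "channelRead", Object.class, Object.class)); // false
        // Inherited unchanged: getMethod finds the adapter's annotated method, so skippable
        System.out.println(isSkippable(MyReadHandler.class, "channelActive", Object.class)); // true
    }
}
```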
The mask for each callback method is defined as follows:

// ChannelHandlerMask.java

final class ChannelHandlerMask {
    // Using to mask which methods must be called for a ChannelHandler.
    static final int MASK_EXCEPTION_CAUGHT = 1;
    static final int MASK_CHANNEL_REGISTERED = 1 << 1;
    static final int MASK_CHANNEL_UNREGISTERED = 1 << 2;
    static final int MASK_CHANNEL_ACTIVE = 1 << 3;
    static final int MASK_CHANNEL_INACTIVE = 1 << 4;
    static final int MASK_CHANNEL_READ = 1 << 5;
    static final int MASK_CHANNEL_READ_COMPLETE = 1 << 6;
    static final int MASK_USER_EVENT_TRIGGERED = 1 << 7;
    static final int MASK_CHANNEL_WRITABILITY_CHANGED = 1 << 8;
    static final int MASK_BIND = 1 << 9;
    static final int MASK_CONNECT = 1 << 10;
    static final int MASK_DISCONNECT = 1 << 11;
    static final int MASK_CLOSE = 1 << 12;
    static final int MASK_DEREGISTER = 1 << 13;
    static final int MASK_READ = 1 << 14;
    static final int MASK_WRITE = 1 << 15;
    static final int MASK_FLUSH = 1 << 16;

    static final int MASK_ONLY_INBOUND =  MASK_CHANNEL_REGISTERED |
            MASK_CHANNEL_UNREGISTERED | MASK_CHANNEL_ACTIVE | MASK_CHANNEL_INACTIVE | MASK_CHANNEL_READ |
            MASK_CHANNEL_READ_COMPLETE | MASK_USER_EVENT_TRIGGERED | MASK_CHANNEL_WRITABILITY_CHANGED;
    private static final int MASK_ALL_INBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_INBOUND;
    static final int MASK_ONLY_OUTBOUND =  MASK_BIND | MASK_CONNECT | MASK_DISCONNECT |
            MASK_CLOSE | MASK_DEREGISTER | MASK_READ | MASK_WRITE | MASK_FLUSH;
    private static final int MASK_ALL_OUTBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_OUTBOUND;
}
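Here is a worked example of the arithmetic. It assumes a hypothetical inbound handler that extends ChannelInboundHandlerAdapter and overrides only channelRead and channelReadComplete, leaving exceptionCaught at its @Skip default as well. The constants are copied from ChannelHandlerMask above. inboundMask is an illustrative helper that replays the inbound branch of mask0 by clearing the bit of each skippable callback.

```java
// Constants copied from ChannelHandlerMask; exampleHandlerMask() describes a hypothetical handler.
class MaskMath {
    static final int MASK_EXCEPTION_CAUGHT = 1;
    static final int MASK_CHANNEL_REGISTERED = 1 << 1;
    static final int MASK_CHANNEL_UNREGISTERED = 1 << 2;
    static final int MASK_CHANNEL_ACTIVE = 1 << 3;
    static final int MASK_CHANNEL_INACTIVE = 1 << 4;
    static final int MASK_CHANNEL_READ = 1 << 5;
    static final int MASK_CHANNEL_READ_COMPLETE = 1 << 6;
    static final int MASK_USER_EVENT_TRIGGERED = 1 << 7;
    static final int MASK_CHANNEL_WRITABILITY_CHANGED = 1 << 8;

    static final int MASK_ONLY_INBOUND = MASK_CHANNEL_REGISTERED | MASK_CHANNEL_UNREGISTERED
            | MASK_CHANNEL_ACTIVE | MASK_CHANNEL_INACTIVE | MASK_CHANNEL_READ
            | MASK_CHANNEL_READ_COMPLETE | MASK_USER_EVENT_TRIGGERED | MASK_CHANNEL_WRITABILITY_CHANGED;
    static final int MASK_ALL_INBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_INBOUND; // 511

    // Replays the inbound branch of mask0: start with every inbound bit set,
    // then clear the bit of each callback that is still at its @Skip default.
    static int inboundMask(int... skippedBits) {
        int mask = MASK_EXCEPTION_CAUGHT | MASK_ALL_INBOUND;
        for (int bit : skippedBits) {
            mask &= ~bit;
        }
        return mask;
    }

    // Hypothetical handler: overrides channelRead and channelReadComplete only.
    static int exampleHandlerMask() {
        return inboundMask(MASK_EXCEPTION_CAUGHT, MASK_CHANNEL_REGISTERED, MASK_CHANNEL_UNREGISTERED,
                MASK_CHANNEL_ACTIVE, MASK_CHANNEL_INACTIVE, MASK_USER_EVENT_TRIGGERED,
                MASK_CHANNEL_WRITABILITY_CHANGED);
    }

    public static void main(String[] args) {
        int mask = exampleHandlerMask();
        System.out.println(mask);                              // 96
        System.out.println(Integer.toBinaryString(mask));      // 1100000
        System.out.println((mask & MASK_CHANNEL_READ) != 0);   // true: channelRead will be invoked
        System.out.println((mask & MASK_CHANNEL_ACTIVE) != 0); // false: channelActive is skipped
    }
}
```

Only bits 5 and 6 survive, so a lookup with MASK_CHANNEL_READ or MASK_CHANNEL_READ_COMPLETE matches this handler, and any other inbound lookup on the same executor passes over it.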

2. How ChannelHandlers Process Messages

Using reading and writing messages as examples, let's see how the ChannelHandlers in a ChannelPipeline process messages and events in order.

Reading messages

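After a Channel reads messages, ChannelPipeline's fireChannelRead method is called, for example here. NioMessageUnsafe serves NioServerSocketChannel, where each "message" is an accepted child channel. NioSocketChannel reads bytes through NioByteUnsafe in AbstractNioByteChannel, which calls fireChannelRead in the same way.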

// AbstractNioMessageChannel.java

private final class NioMessageUnsafe extends AbstractNioUnsafe {

// Code omitted

    @Override
    public void read() {
        // ......

        for (int i = 0; i < size; i ++) {
            readPending = false;
            pipeline.fireChannelRead(readBuf.get(i));
        }

        // ......
    }

// Code omitted
}

// DefaultChannelPipeline.java

// Code omitted

@Override
public final ChannelPipeline fireChannelRead(Object msg) {
    AbstractChannelHandlerContext.invokeChannelRead(head, msg);
    return this;
}

// Code omitted

As shown, fireChannelRead calls AbstractChannelHandlerContext.invokeChannelRead with head, so the read event starts at the head of the pipeline.

// AbstractChannelHandlerContext.java

// Code omitted

static void invokeChannelRead(final AbstractChannelHandlerContext next, Object msg) {
    final Object m = next.pipeline.touch(ObjectUtil.checkNotNull(msg, "msg"), next);
    EventExecutor executor = next.executor();
    if (executor.inEventLoop()) {
        next.invokeChannelRead(m);
    } else {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                next.invokeChannelRead(m);
            }
        });
    }
}

private void invokeChannelRead(Object msg) {
    if (invokeHandler()) {
        try {
            ((ChannelInboundHandler) handler()).channelRead(this, msg);
        } catch (Throwable t) {
            invokeExceptionCaught(t);
        }
    } else {
        fireChannelRead(msg);
    }
}

/**
    * Makes best possible effort to detect if {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} was called
    * yet. If not return {@code false} and if called or could not detect return {@code true}.
    *
    * If this method returns {@code false} we will not invoke the {@link ChannelHandler} but just forward the event.
    * This is needed as {@link DefaultChannelPipeline} may already put the {@link ChannelHandler} in the linked-list
    * but not called {@link ChannelHandler#handlerAdded(ChannelHandlerContext)}.
    */
private boolean invokeHandler() {
    // Store in local variable to reduce volatile reads.
    int handlerState = this.handlerState;
    return handlerState == ADD_COMPLETE || (!ordered && handlerState == ADD_PENDING);
}

// Code omitted

invokeHandler checks the state of the current context. If the check passes, the current ChannelHandler's channelRead method is called. Otherwise, fireChannelRead passes the event on to the next ChannelHandler. head is a HeadContext, which itself implements the ChannelInboundHandler interface, so the method called here is HeadContext's channelRead:
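The decision made by invokeHandler can be tabulated with a tiny sketch. The state names come from the source above. Their int values and the dispatchRead helper are assumptions made for illustration. dispatchRead only records whether the handler would be called or the event forwarded.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of invokeHandler() and the "call handler or forward" branch of invokeChannelRead.
// State names follow the Netty source; the numeric values here are assumptions.
class InvokeHandlerSketch {
    static final int INIT = 0;
    static final int ADD_PENDING = 1;
    static final int ADD_COMPLETE = 2;
    static final int REMOVE_COMPLETE = 3;

    // Same condition as invokeHandler(): handlerAdded has completed, or the context is
    // not ordered and handlerAdded is still pending.
    static boolean invokeHandler(int handlerState, boolean ordered) {
        return handlerState == ADD_COMPLETE || (!ordered && handlerState == ADD_PENDING);
    }

    // Mirrors invokeChannelRead(Object msg): either the handler sees the message,
    // or the event is forwarded unchanged to the next context.
    static String dispatchRead(int handlerState, boolean ordered, Object msg, List<String> trace) {
        if (invokeHandler(handlerState, ordered)) {
            trace.add("channelRead(" + msg + ")");
            return "handled";
        }
        trace.add("fireChannelRead(" + msg + ")");
        return "forwarded";
    }

    public static void main(String[] args) {
        List<String> trace = new ArrayList<>();
        System.out.println(dispatchRead(ADD_COMPLETE, true, "m1", trace));  // handled
        System.out.println(dispatchRead(ADD_PENDING, true, "m2", trace));   // forwarded
        System.out.println(dispatchRead(ADD_PENDING, false, "m3", trace));  // handled
        System.out.println(trace); // [channelRead(m1), fireChannelRead(m2), channelRead(m3)]
    }
}
```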

// DefaultChannelPipeline.java

final class HeadContext extends AbstractChannelHandlerContext implements ChannelOutboundHandler, ChannelInboundHandler {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        ctx.fireChannelRead(msg);
    }
}

It does nothing to the message and simply passes the read event on. Next, let's see what ChannelHandlerContext's fireChannelRead does:

// AbstractChannelHandlerContext.java

@Override
public ChannelHandlerContext fireChannelRead(final Object msg) {
    invokeChannelRead(findContextInbound(MASK_CHANNEL_READ), msg);
    return this;
}

private AbstractChannelHandlerContext findContextInbound(int mask) {
    AbstractChannelHandlerContext ctx = this;
    EventExecutor currentExecutor = executor();
    do {
        ctx = ctx.next;
    } while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_INBOUND));
    return ctx;
}

private static boolean skipContext(
        AbstractChannelHandlerContext ctx, EventExecutor currentExecutor, int mask, int onlyMask) {
    // Ensure we correctly handle MASK_EXCEPTION_CAUGHT which is not included in the MASK_EXCEPTION_CAUGHT
    return (ctx.executionMask & (onlyMask | mask)) == 0 ||
            // We can only skip if the EventExecutor is the same as otherwise we need to ensure we offload
            // everything to preserve ordering.
            //
            // See https://github.com/netty/netty/issues/10067
            (ctx.executor() == currentExecutor && (ctx.executionMask & mask) == 0);
}

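The search starts from the current ChannelHandlerContext and walks forward through the doubly linked list via next. It stops at the first ChannelHandlerContext whose executionMask matches MASK_CHANNEL_READ. As the previous section showed, each ChannelHandlerContext stores the bitwise OR of the masks of all methods its ChannelHandler supports (overrides). A bitwise test is therefore enough to find the nearest context that implements the requested method. The second condition in skipContext changes this for contexts on a different EventExecutor. Such a context is not skipped as long as it handles at least one inbound event, which preserves ordering across executors.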
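The search can be replayed with a hypothetical model in which an int stands in for the EventExecutor. skipContext and findContextInbound below copy the conditions from the source. The node layout is made up for the example: an outbound-only handler, a handler that overrides only channelActive, and a handler that overrides channelRead.

```java
// Hypothetical node: name, executionMask, an int standing in for the EventExecutor, and a next link.
class SearchNode {
    final String name;
    final int executionMask;
    final int executor;
    SearchNode next;

    SearchNode(String name, int executionMask, int executor) {
        this.name = name;
        this.executionMask = executionMask;
        this.executor = executor;
    }
}

class InboundSearch {
    static final int MASK_EXCEPTION_CAUGHT = 1;
    static final int MASK_CHANNEL_ACTIVE = 1 << 3;
    static final int MASK_CHANNEL_READ = 1 << 5;
    static final int MASK_WRITE = 1 << 15;
    static final int MASK_ONLY_INBOUND = 0b1_1111_1110; // bits 1..8, same value (510) as ChannelHandlerMask

    // Same conditions as skipContext in AbstractChannelHandlerContext
    static boolean skipContext(SearchNode ctx, int currentExecutor, int mask, int onlyMask) {
        return (ctx.executionMask & (onlyMask | mask)) == 0
                || (ctx.executor == currentExecutor && (ctx.executionMask & mask) == 0);
    }

    // Same loop as findContextInbound: move along next until a context is not skipped
    static SearchNode findContextInbound(SearchNode from, int mask) {
        SearchNode ctx = from;
        int currentExecutor = from.executor;
        do {
            ctx = ctx.next;
        } while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_INBOUND));
        return ctx;
    }

    // Builds head -> writeOnly -> activeOnly -> reader -> tail and returns the name of the
    // context that head's fireChannelRead(...) would hand the message to.
    static String firstReadTarget(int executorOfActiveOnly) {
        SearchNode head = new SearchNode("head", MASK_EXCEPTION_CAUGHT | MASK_ONLY_INBOUND, 0);
        SearchNode writeOnly = new SearchNode("writeOnly", MASK_WRITE, 1);
        SearchNode activeOnly = new SearchNode("activeOnly", MASK_CHANNEL_ACTIVE, executorOfActiveOnly);
        SearchNode reader = new SearchNode("reader", MASK_CHANNEL_READ, 0);
        SearchNode tail = new SearchNode("tail", MASK_EXCEPTION_CAUGHT | MASK_ONLY_INBOUND, 0);
        head.next = writeOnly;
        writeOnly.next = activeOnly;
        activeOnly.next = reader;
        reader.next = tail;
        return findContextInbound(head, MASK_CHANNEL_READ).name;
    }

    public static void main(String[] args) {
        System.out.println(firstReadTarget(0)); // reader: activeOnly shares head's executor and is skipped
        System.out.println(firstReadTarget(1)); // activeOnly: different executor, so it is not skipped
    }
}
```

The outbound-only node is always skipped, because it has no inbound bits at all. The channelActive-only node is skipped only while it shares the caller's executor. Once it runs elsewhere, the read is handed to it, and its own context forwards the read on its executor.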
The last node of the list is TailContext:

// DefaultChannelPipeline.java

final class TailContext extends AbstractChannelHandlerContext implements ChannelInboundHandler {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        onUnhandledInboundMessage(ctx, msg);
    }

    /**
     * Called once a message hit the end of the {@link ChannelPipeline} without been handled by the user
     * in {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object)}. This method is responsible
     * to call {@link ReferenceCountUtil#release(Object)} on the given msg at some point.
     */
    protected void onUnhandledInboundMessage(ChannelHandlerContext ctx, Object msg) {
        onUnhandledInboundMessage(msg);
        if (logger.isDebugEnabled()) {
            logger.debug("Discarded message pipeline : {}. Channel : {}.",
                         ctx.pipeline().names(), ctx.channel());
        }
    }

    /**
     * Called once a message hit the end of the {@link ChannelPipeline} without been handled by the user
     * in {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object)}. This method is responsible
     * to call {@link ReferenceCountUtil#release(Object)} on the given msg at some point.
     */
    protected void onUnhandledInboundMessage(Object msg) {
        try {
            logger.debug(
                    "Discarded inbound message {} that reached at the tail of the pipeline. " +
                            "Please check your pipeline configuration.", msg);
        } finally {
            ReferenceCountUtil.release(msg);
        }
    }
}

TailContext's channelRead does not pass the event on. It only logs the message at debug level and releases msg.

Writing messages

Let's follow the ctx.write call in OneChannelInBoundHandler's channelReadComplete method:

// AbstractChannelHandlerContext.java

// Code omitted

@Override
public ChannelFuture write(Object msg) {
    return write(msg, newPromise());
}

@Override
public ChannelFuture write(final Object msg, final ChannelPromise promise) {
    write(msg, false, promise);

    return promise;
}

private void write(Object msg, boolean flush, ChannelPromise promise) {
    ObjectUtil.checkNotNull(msg, "msg");
    try {
        if (isNotValidPromise(promise, true)) {
            ReferenceCountUtil.release(msg);
            // cancelled
            return;
        }
    } catch (RuntimeException e) {
        ReferenceCountUtil.release(msg);
        throw e;
    }

    final AbstractChannelHandlerContext next = findContextOutbound(flush ?
            (MASK_WRITE | MASK_FLUSH) : MASK_WRITE);
    final Object m = pipeline.touch(msg, next);
    EventExecutor executor = next.executor();
    if (executor.inEventLoop()) {
        if (flush) {
            next.invokeWriteAndFlush(m, promise);
        } else {
            next.invokeWrite(m, promise);
        }
    } else {
        final WriteTask task = WriteTask.newInstance(next, m, promise, flush);
        if (!safeExecute(executor, task, promise, m, !flush)) {
            // We failed to submit the WriteTask. We need to cancel it so we decrement the pending bytes
            // and put it back in the Recycler for re-use later.
            //
            // See https://github.com/netty/netty/issues/8343.
            task.cancel();
        }
    }
}

private AbstractChannelHandlerContext findContextOutbound(int mask) {
    AbstractChannelHandlerContext ctx = this;
    EventExecutor currentExecutor = executor();
    do {
        ctx = ctx.prev;
    } while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_OUTBOUND));
    return ctx;
}

// Code omitted

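The chain of overloaded write methods ends in findContextOutbound. It walks backward through the doubly linked list via prev to find the nearest ChannelHandlerContext that implements write; for writeAndFlush, a context that implements either write or flush qualifies. It then calls that context's invokeWrite or invokeWriteAndFlush method.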

// AbstractChannelHandlerContext.java

// Code omitted

void invokeWrite(Object msg, ChannelPromise promise) {
    if (invokeHandler()) {
        invokeWrite0(msg, promise);
    } else {
        write(msg, promise);
    }
}

private void invokeWrite0(Object msg, ChannelPromise promise) {
    try {
        ((ChannelOutboundHandler) handler()).write(this, msg, promise);
    } catch (Throwable t) {
        notifyOutboundHandlerException(t, promise);
    }
}

// Code omitted

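As with reading, invokeHandler is checked first. If it passes, the ChannelHandler of the found context is called. Otherwise the write event keeps travelling toward the head. When the write reaches the head, HeadContext's write method is called: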

// DefaultChannelPipeline.java

final class HeadContext extends AbstractChannelHandlerContext implements ChannelOutboundHandler, ChannelInboundHandler {

    private final Unsafe unsafe;

// Code omitted

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
        unsafe.write(msg, promise);
    }

    @Override
    public void flush(ChannelHandlerContext ctx) {
        unsafe.flush();
    }

// Code omitted
}

In the end, the message is written through unsafe.write.
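This implementation has a practical consequence for handler order. When adding ChannelHandlers to a ChannelPipeline, add the ChannelOutboundHandlers as early as possible, ahead of the ChannelInboundHandlers that write. Otherwise, when a ChannelInboundHandler writes with ctx.write, a ChannelOutboundHandler placed after it never receives the event, because the search only moves toward head. ctx.channel().write(...) behaves differently: it goes through pipeline.write and starts from tail, so every ChannelOutboundHandler in the pipeline sees it.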
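The ordering rule can be checked with a model of findContextOutbound. OutNode, pipeline(...) and writeFrom(...) are hypothetical. The model assumes that all contexts share one executor and that every outbound handler forwards the write with ctx.write(...). Names starting with "out" get MASK_WRITE; all other handlers are treated as inbound-only.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical node: a name, an executionMask and a prev link (toward head).
class OutNode {
    final String name;
    final int executionMask;
    OutNode prev;

    OutNode(String name, int executionMask) {
        this.name = name;
        this.executionMask = executionMask;
    }
}

class WriteOrderDemo {
    static final int MASK_CHANNEL_READ = 1 << 5;
    static final int MASK_WRITE = 1 << 15;
    static final int MASK_ONLY_OUTBOUND = 0xFF << 9; // bits 9..16 (bind .. flush)

    // skipContext from the source, assuming every context runs on the same executor
    static boolean skip(OutNode ctx, int mask) {
        return (ctx.executionMask & (MASK_ONLY_OUTBOUND | mask)) == 0
                || (ctx.executionMask & mask) == 0;
    }

    // findContextOutbound: walk prev until a context that handles write
    static OutNode findContextOutbound(OutNode from) {
        OutNode ctx = from;
        do {
            ctx = ctx.prev;
        } while (skip(ctx, MASK_WRITE));
        return ctx;
    }

    // Records which contexts receive the write; the walk ends at head (prev == null),
    // whose write() hands the message to unsafe.write.
    static List<String> writeFrom(OutNode start) {
        List<String> seen = new ArrayList<>();
        OutNode ctx = start;
        do {
            ctx = findContextOutbound(ctx);
            seen.add(ctx.name);
        } while (ctx.prev != null);
        return seen;
    }

    // Builds head, the given handlers, then tail, linked through prev; returns them in order.
    static OutNode[] pipeline(String... order) {
        OutNode[] nodes = new OutNode[order.length + 2];
        nodes[0] = new OutNode("head", MASK_WRITE | MASK_CHANNEL_READ);
        for (int i = 0; i < order.length; i++) {
            int mask = order[i].startsWith("out") ? MASK_WRITE : MASK_CHANNEL_READ;
            nodes[i + 1] = new OutNode(order[i], mask);
        }
        nodes[nodes.length - 1] = new OutNode("tail", MASK_CHANNEL_READ);
        for (int i = 1; i < nodes.length; i++) {
            nodes[i].prev = nodes[i - 1];
        }
        return nodes;
    }

    public static void main(String[] args) {
        OutNode[] good = pipeline("out", "in"); // head, out, in, tail
        OutNode[] bad = pipeline("in", "out");  // head, in, out, tail
        System.out.println(writeFrom(good[2]));             // ctx.write from "in": [out, head]
        System.out.println(writeFrom(bad[1]));              // ctx.write from "in": [head], "out" is bypassed
        System.out.println(writeFrom(bad[bad.length - 1])); // channel().write starts at tail: [out, head]
    }
}
```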

Original article: https://www.cnblogs.com/niklai/p/12995564.html