Implementing RPC with Netty

  1. Understand how RPC works
  2. Understand RPC messages (using Dubbo as an example)
  3. Requirements analysis
  4. Detailed design
  5. Coding

The earlier Redis client and WebSocket server exercises consolidated our hands-on Netty skills.

Next, we take on one more extension exercise.

Understand how RPC works

A typical RPC use case involves:

  • Service discovery
  • Load balancing
  • Fault tolerance
  • Network transport
  • Serialization

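In general, a remote procedure call (RPC) works like this: a local dynamic proxy hides the communication details, a serialization component encodes the request, the request travels over the network to the server, the server executes the real service code, and the result is sent back to the client, where it is deserialized and handed to the calling method.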

The RPC call flow involves the following modules:

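serviceClient: this module wraps the API that the server exposes, so the client can call a remote service as if it were a local API. It is usually built on dynamic proxies: when the client calls an API method, serviceClient runs the proxy logic, asks the remote server to execute the real method, and returns the response to the client application as the result of the local API method. It plays the same role as the stub in RMI.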
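The proxy idea described above can be sketched with the JDK's `java.lang.reflect.Proxy`. This is only an illustration under stated assumptions: `GreetingService`, `stubFor`, and the `SENT` list are hypothetical names, and the "remote" result is faked locally so that the sketch runs without a server.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

public class StubSketch {

    /** Hypothetical service interface, used only for illustration. */
    public interface GreetingService {
        String greet(String name);
    }

    /** Records what a real stub would send; stands in for the network layer. */
    public static final List<String> SENT = new ArrayList<>();

    /** Builds a proxy that intercepts every call made through the interface. */
    @SuppressWarnings("unchecked")
    public static <T> T stubFor(Class<T> api) {
        InvocationHandler handler = new InvocationHandler() {
            @Override
            public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                // toString/hashCode/equals are answered locally, not remotely.
                if (method.getDeclaringClass() == Object.class) {
                    return method.invoke(this, args);
                }
                // A real serviceClient would serialize the call and write it to a channel here.
                SENT.add(api.getName() + "#" + method.getName());
                // Faked "remote" result (assumption: a single String argument).
                return "echo:" + args[0];
            }
        };
        return (T) Proxy.newProxyInstance(api.getClassLoader(), new Class<?>[]{api}, handler);
    }

    public static void main(String[] args) {
        GreetingService service = stubFor(GreetingService.class);
        System.out.println(service.greet("netty")); // prints echo:netty
    }
}
```

The caller only ever sees `GreetingService`; everything about how the call is delivered stays inside the handler.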

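processor: the server hosts many methods, so when a client request arrives the server has to locate the specific method on the specific object and execute it. That is the job of the processor module. It usually relies on reflection to find the method that runs the real logic, although some RPC frameworks write the mapping into a Map under fixed rules when the server initializes, so the target can be looked up directly. It plays the same role as the skeleton in RMI.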
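A minimal sketch of the Map-based lookup described above might look like this. The names (`Calculator`, `register`, `dispatch`) are hypothetical, and the key format (class name, method name, argument count) is a simplification chosen for this example, not the scheme used by any particular framework.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class ProcessorSketch {

    /** Hypothetical service interface and implementation, for illustration only. */
    public interface Calculator {
        int add(int a, int b);
    }

    public static class CalculatorImpl implements Calculator {
        @Override
        public int add(int a, int b) {
            return a + b;
        }
    }

    /** One registered method together with the object it runs on. */
    private static final class Entry {
        final Method method;
        final Object target;

        Entry(Method method, Object target) {
            this.method = method;
            this.target = target;
        }
    }

    private static final Map<String, Entry> REGISTRY = new HashMap<>();

    // Simplified key: overloads with the same argument count would collide.
    private static String key(String className, String methodName, int argCount) {
        return className + "#" + methodName + "/" + argCount;
    }

    /** Fills the map once, at server start-up. */
    public static void register(Class<?> api, Object impl) {
        for (Method m : api.getMethods()) {
            REGISTRY.put(key(api.getName(), m.getName(), m.getParameterCount()), new Entry(m, impl));
        }
    }

    /** Looks up the target method and invokes it reflectively. */
    public static Object dispatch(String className, String methodName, Object... args) {
        Entry entry = REGISTRY.get(key(className, methodName, args.length));
        if (entry == null) {
            throw new IllegalArgumentException("service not found: " + className + "#" + methodName);
        }
        try {
            return entry.method.invoke(entry.target, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        register(Calculator.class, new CalculatorImpl());
        System.out.println(dispatch(Calculator.class.getName(), "add", 1, 2)); // prints 3
    }
}
```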

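protocol: the protocol layer, which is where the core technology of each RPC component lies. It generally covers encoding and decoding, that is, serialization and deserialization. Sometimes encoding and decoding involve more than object serialization and also include extra parsing of communication-related byte streams. Serialization tools include hessian, protobuf, avro, thrift, the JSON family, the XML family, and so on. RMI uses the JDK's built-in serialization directly.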
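As a small illustration of the serialization half of the protocol layer, the following sketch round-trips an object through the JDK's built-in serialization, the mechanism the paragraph above attributes to RMI. The class and method names are hypothetical; other formats such as hessian or protobuf would sit behind the same two methods.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;

public class JavaSerializationSketch {

    /** Turns a Serializable object into bytes that can be sent as a message body. */
    public static byte[] serialize(Object value) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        // try-with-resources closes (and therefore flushes) the stream before the bytes are read.
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Rebuilds the object on the receiving side. */
    public static Object deserialize(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        List<String> original = Arrays.asList("hello", "rpc");
        byte[] body = serialize(original);
        Object copy = deserialize(body);
        System.out.println(original.equals(copy)); // prints true
    }
}
```

Note that deserializing untrusted bytes with JDK serialization is unsafe, which is one reason real frameworks tend to offer other formats.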

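transport: the transport layer, which covers the network communication between server and client. It is kept separate from the I/O layer below because the transport layer handles the server/client communication exchange and does not deal with the low-level logic of handling connection requests and responses.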

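I/O: this module exists mainly for performance, which may call for different IO models and threading models. In practice it is usually tightly coupled to the transport layer above, and the two are collectively called the remote module.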

Understand RPC messages (using Dubbo as an example)

Dubbo message format

Dubbo header
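For reference, the Dubbo header is a fixed 16 bytes: a 2-byte magic number (0xdabb), one flag byte (request, two-way, and event bits plus a 5-bit serialization id), one status byte, an 8-byte request id, and a 4-byte body length. The sketch below packs and reads that layout with `java.nio.ByteBuffer`. The class and method names are hypothetical, and it only illustrates the byte layout; it is not Dubbo's own codec.

```java
import java.nio.ByteBuffer;

public class HeaderSketch {
    static final int HEADER_LENGTH = 16;
    static final short MAGIC = (short) 0xdabb;     // Dubbo's magic number
    static final byte FLAG_REQUEST = (byte) 0x80;  // 1000 0000
    static final byte FLAG_TWO_WAY = (byte) 0x40;  // 0100 0000
    static final byte FLAG_EVENT = (byte) 0x20;    // 0010 0000
    static final int SERIALIZATION_MASK = 0x1f;    // 0001 1111

    /** Packs the header fields in wire order (big-endian, ByteBuffer's default). */
    public static byte[] encode(boolean request, boolean twoWay, boolean event,
                                int serializationId, byte status, long id, int bodyLength) {
        byte flags = (byte) (serializationId & SERIALIZATION_MASK);
        if (request) flags |= FLAG_REQUEST;
        if (twoWay) flags |= FLAG_TWO_WAY;
        if (event) flags |= FLAG_EVENT;
        return ByteBuffer.allocate(HEADER_LENGTH)
                .putShort(MAGIC)      // bytes 0-1
                .put(flags)           // byte 2
                .put(status)          // byte 3
                .putLong(id)          // bytes 4-11
                .putInt(bodyLength)   // bytes 12-15
                .array();
    }

    public static boolean isRequest(byte[] header) {
        return (header[2] & FLAG_REQUEST) != 0;
    }

    public static int serializationId(byte[] header) {
        return header[2] & SERIALIZATION_MASK;
    }

    public static long requestId(byte[] header) {
        return ByteBuffer.wrap(header).getLong(4);
    }

    public static int bodyLength(byte[] header) {
        return ByteBuffer.wrap(header).getInt(12);
    }

    public static void main(String[] args) {
        byte[] header = encode(true, true, false, 2, (byte) 0, 42L, 128);
        System.out.println(requestId(header) + " " + bodyLength(header)); // prints 42 128
    }
}
```

Because the body length sits at a fixed offset, a decoder can wait until 16 bytes are available, read the length, and then wait for the rest of the frame.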

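Requirements analysis

Our goal is to master Netty development techniques, so this RPC demo follows a minimal-implementation principle.

Requirements such as load balancing, fault tolerance, and proxy transparency are ignored.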

Detailed design

Coding

package com.shar.netty.netty.rpc;

import com.alibaba.dubbo.common.io.Bytes;
import com.shar.netty.netty.ByteBufUtil;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Promise;
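// Note: this is the JDK-internal copy of ASM. It compiles on JDK 8, but the package is not exported on JDK 9+.
import jdk.internal.org.objectweb.asm.Type;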

import java.io.*;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

public class RpcDemo {

    // static, like the other nested classes, so it can be created without an RpcDemo instance
    public static class RpcServer {
        ExecutorService threadPool = Executors.newFixedThreadPool(500);
        private Map<String, ServiceBean> register = new HashMap<>();

        public void start(int port) throws InterruptedException {
            ServerBootstrap bootstrap = new ServerBootstrap();
            EventLoopGroup boss = new NioEventLoopGroup(1);
            EventLoopGroup work = new NioEventLoopGroup(8);
            bootstrap.group(boss, work).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<Channel>() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    ch.pipeline().addLast("codec", new RpcCodec());
                    ch.pipeline().addLast("dispatch", new Dispatch());
                }
            }).bind(port).sync();
            System.out.println("Server started");
        }

        private class Dispatch extends SimpleChannelInboundHandler<Transfer> {
            @Override
            protected void channelRead0(ChannelHandlerContext ctx, Transfer transfer) {
                if (transfer.heartbeat) { // heartbeat handling
                    Transfer t = new Transfer(transfer.id);
                    t.heartbeat = true;
                    t.request = false;
                    ctx.writeAndFlush(t); // reply with a heartbeat
                } else {
                    threadPool.submit(() -> {
                        Transfer to = doDispatchRequest(transfer);
                        ctx.writeAndFlush(to); // called off the IO thread; Netty hands the write over to the IO thread
                    });
                }
            }

            // Handle a business request
            Transfer doDispatchRequest(Transfer from) {
                Request request = (Request) from.target;
                Transfer to = new Transfer(from.id);
                to.request = false;
                to.serializableId = from.serializableId;
                Response response = new Response();
                try {
                    String serverId = request.getClassName() + request.getMethodDesc();
                    ServiceBean serverBean = register.get(serverId);
                    if (serverBean == null) {
                        throw new IllegalArgumentException("Service not found: " + serverId);
                    }
                    Object result = serverBean.invoke(request.getArgs());
                    response.setResult(result);
                    to.status = Transfer.STATUS_OK;
                } catch (Throwable e) {
                    e.printStackTrace();
                    response.setError(e);
                    to.status = Transfer.STATUS_ERROR;
                }
                to.target = response;
                return to;
            }
        }

        public void registerServer(Class<?> serviceInterface, Object serverBean) {
            assert serviceInterface.isInterface();
            for (Method method : serviceInterface.getMethods()) {
                int modifiers = method.getModifiers();
                if (Modifier.isStatic(modifiers) || Modifier.isNative(modifiers)) {
                    continue;
                }
                String methodDescriptor = Type.getMethodDescriptor(method);
                String key = serviceInterface.getName() + method.getName() + methodDescriptor;
                register.put(key, new ServiceBean(method, serverBean));
            }
        }
    }

    private static class ServiceBean {
        Method method;
        Object target;

        public ServiceBean(Method method, Object target) {
            this.method = method;
            this.target = target;
        }

        public Object invoke(Object[] args) throws Exception {
            return method.invoke(target, args);
        }
    }


    public static class RpcCodec extends ByteToMessageCodec {
        protected static final int HEADER_LENGTH = 16;
        protected static final short MAGIC = 0xdad;
        protected static final ByteBuf MAGIC_BUF = Unpooled.copyShort(MAGIC);
        protected static final byte FLAG_REQUEST = (byte) 0x80;//1000 0000
        protected static final byte FLAG_TWO_WAY = (byte) 0x40; //0100 0000
        protected static final byte FLAG_EVENT = (byte) 0x20;  //0010 0000
        protected static final int SERIALIZATION_MASK = 0x1f;  //0001 1111

        // Encode
        @Override
        protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception {
            if (msg instanceof Transfer) {
                doEncode((Transfer) msg, out);
            } else {
                throw new IllegalArgumentException();
            }
        }

        // Decode
        @Override
        protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception {
            Transfer transfer = doDecode(in);
            if (transfer != null) {
                out.add(transfer);
            }
        }

        // Encode a Transfer as a 16-byte header followed by the body
        protected void doEncode(Transfer data, ByteBuf buf) {
            byte[] header = new byte[HEADER_LENGTH];
            Bytes.short2bytes(MAGIC, header);

            header[2] = data.serializableId;
            if (data.request) header[2] |= FLAG_REQUEST;
            if (data.twoWay) header[2] |= FLAG_TWO_WAY;
            if (data.heartbeat) header[2] |= FLAG_EVENT;
            if (!data.request) header[3] = data.status;

            Bytes.long2bytes(data.id, header, 4); // the id takes 8 bytes
            int len = 0;
            byte[] body = new byte[0];
            if (!data.heartbeat) {
                body = serialize(data.serializableId, data.target);
                len = body.length;
            }
            Bytes.int2bytes(len, header, 12);
            buf.writeBytes(header);
            buf.writeBytes(body);
        }

        // Decode one frame; returns null when more bytes are needed
        protected Transfer doDecode(ByteBuf in) {
            // Assumption: index is an offset from readerIndex, as the isReadable/skipBytes calls below imply
            int index = ByteBufUtil.indexOf(in, MAGIC_BUF);
            if (index < 0) {
                return null; // magic number not found yet, need more bytes
            }
            if (!in.isReadable(index + HEADER_LENGTH)) {
                return null; // header incomplete, need more bytes
            }
            byte[] header = new byte[HEADER_LENGTH];
            ByteBuf slice = in.slice();
            slice.skipBytes(index); // the header starts at the magic number, not at readerIndex
            slice.readBytes(header);
            int length = Bytes.bytes2int(header, 12);

            if (!in.isReadable(index + HEADER_LENGTH + length)) {
                return null; // body incomplete, need more bytes
            }
            Transfer transfer = new Transfer(Bytes.bytes2long(header, 4));
            transfer.heartbeat = (header[2] & FLAG_EVENT) != 0;
            transfer.request = (header[2] & FLAG_REQUEST) != 0;
            transfer.twoWay = (header[2] & FLAG_TWO_WAY) != 0;
            transfer.serializableId = (byte) (header[2] & SERIALIZATION_MASK);
            transfer.status = header[3];
            if (!transfer.heartbeat) {
                byte[] content = new byte[length];
                slice.readBytes(content);
                transfer.target = deserialize(transfer.serializableId, content);
            }
            // Consume the whole frame, plus any bytes that preceded the magic number
            in.skipBytes(index + HEADER_LENGTH + length);

            return transfer;
        }

        // Serialize
        private byte[] serialize(byte serializableId, Object target) {

            if (serializableId == Transfer.SERIALIZABLE_JAVA) { // JDK serialization
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                // try-with-resources closes (and flushes) the stream before the bytes are read
                try (ObjectOutputStream stream = new ObjectOutputStream(out)) {
                    stream.writeObject(target);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
                return out.toByteArray();
            } else {
                throw new UnsupportedOperationException();
            }
        }

        // Deserialize
        private Object deserialize(byte serializableId, byte[] bytes) {
            if (serializableId == Transfer.SERIALIZABLE_JAVA) { // JDK serialization
                try {
                    ObjectInputStream stream =
                            new ObjectInputStream(new ByteArrayInputStream(bytes));
                    return stream.readObject();
                } catch (IOException | ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
            } else {
                throw new UnsupportedOperationException();
            }
        }
    }

    public static class RpcClient {
        static AtomicLong atomicLong = new AtomicLong(100);
        private Channel channel;
        private Map<Long, Promise<Response>> results = new HashMap<>();

        public static long getNextId() {
            return atomicLong.getAndIncrement();
        }

        public void init(String address, int port) throws InterruptedException {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(new NioEventLoopGroup(1))
                    .channel(NioSocketChannel.class);
            bootstrap.handler(new ChannelInitializer<Channel>() {
                @Override
                protected void initChannel(Channel ch) {
                    ch.pipeline().addLast("codec", new RpcCodec());
                    ch.pipeline().addLast("resultSet", new ResultFill()); // fills in pending results
                }
            });
            ChannelFuture connect = bootstrap.connect(address, port);
            channel = connect.sync().channel();
            System.out.println("Connected");
            // Send a heartbeat every two seconds
            channel.eventLoop().scheduleWithFixedDelay(() -> {
                Transfer transfer=new Transfer(getNextId());
                transfer.heartbeat=true;
                channel.writeAndFlush(transfer);
            },2000,2000, TimeUnit.MILLISECONDS);
        }

        public Response invokerRemote(Class<?> serverInterface,
                                      String methodDesc,
                                      Object[] args) throws InterruptedException, ExecutionException, TimeoutException {
            Request request = new Request(serverInterface.getName(), methodDesc);
            request.setArgs(args);
            Transfer transfer = new Transfer(getNextId());
            transfer.request=true;
            transfer.serializableId=Transfer.SERIALIZABLE_JAVA;
            transfer.target = request;
            DefaultPromise<Response> resultPromise = new DefaultPromise<>(channel.eventLoop());
            // Register the pending result once the write has succeeded
            channel.writeAndFlush(transfer).addListener(future ->
                    { // runs on the IO thread
                        if (future.cause() != null) { // write failed
                            resultPromise.setFailure(future.cause()); // a failed write must complete the promise
                        } else {    // write succeeded
                            results.put(transfer.id, resultPromise);
                        }
                    }
            );

            return resultPromise.get(10000, TimeUnit.MILLISECONDS);
        }

        private class ResultFill extends SimpleChannelInboundHandler<Transfer> {
            @Override
            protected void channelRead0(ChannelHandlerContext ctx, Transfer msg) {
                if (msg.heartbeat) {
                    System.out.println(String.format("Heartbeat reply from server: %s",
                            ctx.channel().remoteAddress()));
                } else {
                    Promise<Response> promise = results.remove(msg.id);
                    if (promise != null) { // ignore responses that have no pending call
                        promise.setSuccess((Response) msg.target); // fill in the result
                    }
                }
            }
        }

        public <T> T getRemoteService(Class<T> serviceInterface) {
            assert serviceInterface.isInterface();
            Object o = Proxy.newProxyInstance(getClass().getClassLoader(), new Class[]{serviceInterface}, new InvocationHandler() {
                @Override
                public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
                    if (Object.class.equals(method.getDeclaringClass())) {
                        return method.invoke(this, args);
                    }

                    String methodDescriptor = method.getName()+ Type.getMethodDescriptor(method);
                    Response response = invokerRemote(serviceInterface, methodDescriptor, args);
                    if (response.getError() != null) {
                        throw new RuntimeException("Remote service call failed:", response.getError());
                    }
                    return response.getResult();
                }
            });
            return (T) o;
        }


    }

    public static class Transfer {
        public static final byte STATUS_ERROR = 0;
        public static final byte STATUS_OK = 1;
        public static final byte STATUS_ILLEGAL = 2;
        public static final byte SERIALIZABLE_JAVA=1;
        public static final byte SERIALIZABLE_HESSIAN2=2;
        public static final byte SERIALIZABLE_JSON=3;

        boolean request;
        byte serializableId; // 1: java, 2: hessian2, 3: json
        boolean twoWay;
        boolean heartbeat;
        long id;
        byte status;    // 1: OK, 0: error, 2: illegal request
        Object target;

        public Transfer(long id) {
            this.id = id;
        }

        void copy(Transfer from) {
            this.request = from.request;
            this.serializableId = from.serializableId;
            this.twoWay = from.twoWay;
            this.heartbeat = from.heartbeat;
            this.id = from.id;
            this.status = from.status;
            this.target = from.target;
        }
    }

    public static class Request   implements java.io.Serializable {
        private String methodDesc;
        private String className;
        private Object args[];


        public Request(String className,String methodDesc) {
            this.className=className;
            this.methodDesc=methodDesc;
        }

        public String getMethodDesc() {
            return methodDesc;
        }

        public void setMethodDesc(String methodDesc) {
            this.methodDesc = methodDesc;
        }

        public String getClassName() {
            return className;
        }

        public void setClassName(String className) {
            this.className = className;
        }

        public Object[] getArgs() {
            return args;
        }

        public void setArgs(Object[] args) {
            this.args = args;
        }

    }

    public static class Response   implements java.io.Serializable {

        Object result;
        Throwable error;


        public Object getResult() {
            return result;
        }

        public void setResult(Object result) {
            this.result = result;
        }

        public Throwable getError() {
            return error;
        }

        public void setError(Throwable error) {
            this.error = error;
        }

    }
}

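Please credit the source when reposting. Readers are welcome to verify the sources cited in this article and to point out any errors or unclear wording by email at sharlot2050@foxmail.com.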

Title: Implementing RPC with Netty

Author: 夏来风

Published: 2020-08-02, 23:58:34

Original link: http://www.demo1024.com/blog/netty-RPC/

License: "Attribution-NonCommercial-ShareAlike 4.0". Please keep the original link and author when reposting.