C语言之浅析网络包解析

1.这几天研究skynet中的 lua-netpack.c 中的解析数据包过程。于是把lua部分去掉,修改了一些接口,留下解包相关的代码。再结合云风写的网络代码的例子,

   写了一个最简单形式的客户端封包,服务器解包的代码,作为学习笔记的同时也希望能够帮助一些像我一样的新手学习理解封包,解包的概念。

 ps:修改的代码实现了,当收到一个整包时,打印整包内容的功能,但是并没有从完整的包队列中pop完整包的接口,可自行加上。

2.服务器和客户端程序简介:

  server:server端实现的很简单,基于云风给出的skynet网络层代码的例子。在收到客户端发来数据以后,加上了解包过程。

      1):生成。将代码放到一个文件夹下以后,直接运行make便可得到 socket-server, 执行 ./socket-server 运行程序。

      2):可自行在解包处 filter_data_里加上log来具体分析解包过程,ps:我就是这样搞的。

      3):理解流程以后,可自行实现修改 message_resolve.c和test.c来实现更详细的测试

      client: 该客户端比较简单,只是建立连接,然后输入要发送的消息条数(n),程序会自动随机生成n条长度为 1 到 65535(即 2^16-1)字节的消息(因为skynet的消息长度字段是两个字节)

         ps:之前是手动输入消息,但是消息太短会导致服务器解包过程中的一些条件覆盖不到,于是只需输入消息条数,让程序自动生成随机长度消息(实际上只是字符'a'的序列,可以自行修改

       生成消息的函数 gen_msg 来生成其它内容的消息)

      1):client只有一个文件 socket-client.c,直接通过命令行 gcc -o socket-client socket-client.c 来获得客户端的可执行程序 socket-client,

        运行./socket-client 来执行客户端程序。

      2):修改 #define MSG_LEN 来更改随机消息长度范围,可设置小一些。

3.解包过程浅析:

  1)客户端消息格式: 2字节消息长度(大端字节序) + 消息内容(长度由前两个字节给出)

  服务器解包过程中的几种情况:

  1)首次收到消息包的时候,可能只收到一个字节。因为每个包开头的两个字节组成包长字段,所以这种情况要将其放入未接收完队列,待下次从该fd中读取剩余消息;同理,如果包长字段已收到但包体不完整,也放入未接收完队列;

  2)当收到的是整个包的时候,直接获得完整的消息;

  3)当收到的数据多于一个包时,将获得的完整包放入完整的消息包队列,然后将剩余的消息放入未完成消息队列中,下次接收到消息时,找到该fd的未完成消息队列,

    继续解析。

4.代码:结构如下

Makefile:

# Build the test server. -I. lets the sources find message_resolve.h and the
# other local headers (-I expects a directory, not a header file).
# Recipe lines must start with a TAB character.
socket-server : socket_server.c  message_resolve.c  test.c
	gcc -g -Wall -o $@ $^ -lpthread -I.

.PHONY: clean
clean:
	rm -f socket-server

message_resolve.c:

#include "skynet_malloc.h"
#include "message_resolve.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

/* QUEUESIZE: slots added to the complete-packet ring per allocation.
 * HASHSIZE: buckets in the per-fd table of partially received packets.
 * SMALLSTRING and the TYPE_* codes are not referenced in this file
 * (presumably carried over from skynet's lua-netpack.c). */
#define QUEUESIZE 1024
#define HASHSIZE 4096
#define SMALLSTRING 2048

#define TYPE_DATA 1
#define TYPE_MORE 2
#define TYPE_ERROR 3
#define TYPE_OPEN 4
#define TYPE_CLOSE 5
#define TYPE_WARNING 6


/*
    Each package is uint16 + data , uint16 (serialized in big-endian) is the number of bytes comprising the data .
*/

/* One packet body (without its 2-byte length prefix). */
struct netpack {
    int id;         /* socket fd the packet came from */
    int size;       /* body length in bytes */
    void * buffer;  /* heap-allocated body */
};

/* Per-fd state for a packet that has not been fully received yet. */
struct uncomplete {
    struct netpack pack;        /* pack.size = expected body length, pack.buffer = body so far */
    struct uncomplete * next;   /* next record in the same hash bucket */
    int read;                   /* body bytes received so far, or -1 if only the first length byte arrived */
    int header;                 /* that first (high) length byte when read == -1 */
};

/*
 * Ring buffer of complete packets plus the hash of uncomplete records.
 * head/tail index queue[]; cap is the number of ring slots. queue[] is
 * declared with QUEUESIZE entries, and expand_queue() allocates extra bytes
 * after the struct so a grown queue has more than QUEUESIZE usable slots.
 * cap must never exceed the number of slots actually allocated.
 */
struct queue {
    int cap;
    int head;
    int tail;
    struct uncomplete * hash[HASHSIZE];
    struct netpack queue[QUEUESIZE];
};


/* The single module-global queue: created by init_queue(), replaced by expand_queue(). */
static struct queue *Q;

void
init_queue()
{
    struct queue *q = (struct queue *)skynet_malloc(sizeof(*q));
    q->head = 0;
    q->tail = 0;
    q->cap = HASHSIZE;

    int i;
    for (i = 0; i < HASHSIZE; i++)
    {
        q->hash[i] = NULL;
    }

    Q = q;
}

/*static void
clear_list(struct uncomplete * uc) {
    while (uc) 
    {
            skynet_free(uc->pack.buffer);
            void * tmp = uc;
            uc = uc->next;
            skynet_free(tmp);
    }
}*/

/*static int
clear(struct queue *q)
{
    if (q == NULL)
    {
        return 0;
    }
    int i;
    for (i = 0; i < HASHSIZE; i++)
    {
        clear_list(q->hash[i]);
        q->hash[i] = NULL;
    }
    if(q->head > q->tail)
    {
        q->tail += q->cap;
    }
    for (i = q->head; i < q->tail; i++)
    {
        struct netpack *np = &q->queue[i % q->cap];
        skynet_free(np->buffer);
        np->buffer = NULL;
    }
    q->head = q->tail = 0;

    return 0;
}*/

/*
 * Map a socket fd to a bucket index in [0, HASHSIZE).
 *
 * The sum is formed and reduced in unsigned arithmetic, and only the final
 * bucket number is converted back to int. The previous version cast to int
 * before applying %, so a sum above INT_MAX produced a negative index. Its
 * signed addition could also overflow for very large fds.
 */
static inline int
hash_fd(int fd)
{
        uint32_t a = (uint32_t)fd >> 24;
        uint32_t b = (uint32_t)fd >> 12;
        uint32_t c = (uint32_t)fd;
        return (int)((a + b + c) % HASHSIZE);
}

/* Decode the 2-byte big-endian length prefix stored at buffer[0..1]. */
static inline int
read_size(uint8_t *buffer)
{
        int high = buffer[0];
        int low = buffer[1];
        return high * 256 + low;
}

/*
 * Allocate a zeroed uncomplete record for `fd` and push it onto the front
 * of fd's hash bucket in q. The caller fills in header/read/pack.
 */
static struct uncomplete *
save_uncomplete(struct queue *q, int fd)
{
        struct uncomplete *node = (struct uncomplete *)skynet_malloc(sizeof(struct uncomplete));
        memset(node, 0, sizeof(*node));
        node->pack.id = fd;

        struct uncomplete **bucket = &q->hash[hash_fd(fd)];
        node->next = *bucket;
        *bucket = node;
        return node;
}

/*
 * Find the pending (partially received) record for `fd` in q, unlink it
 * from its hash bucket and return it; the caller then owns it and must
 * re-link or free it. Returns NULL if q is NULL or fd has no record.
 *
 * Different fds can share a bucket, so records are matched on pack.id == fd.
 */
static struct uncomplete * 
find_uncomplete(struct queue *q, int fd)
{
        if (q == NULL)
        {
                return NULL;
        }
        struct uncomplete **link = &q->hash[hash_fd(fd)];
        for (; *link != NULL; link = &(*link)->next)
        {
                struct uncomplete *cur = *link;
                if (cur->pack.id == fd)
                {
                        *link = cur->next;
                        return cur;
                }
        }
        return NULL;
}

/*
 * Grow the packet ring by QUEUESIZE slots. push_data() calls this when
 * head == tail after a push, i.e. the ring is full.
 *
 * The new queue is allocated with room for the old cap entries plus
 * QUEUESIZE more. struct queue already embeds QUEUESIZE slots; the extra
 * sizeof(struct netpack) * q->cap bytes extend the trailing array.
 * Pending uncomplete records move to the new queue. Packets are copied
 * oldest-first, so the new ring starts at index 0 and holds q->cap
 * entries. Q is repointed to the new queue. Aborts if allocation fails.
 *
 * NOTE(review): the old queue is emptied but not freed. Callers
 * (push_more / filter_data_) may keep using their old `q` pointer after
 * push_data() returns, and freeing it here would turn that into a
 * use-after-free.
 */
static inline void
expand_queue(struct queue *q)
{
        struct queue *nq = skynet_malloc(sizeof(struct queue) + sizeof(struct netpack) * q->cap);
        if (nq == NULL)
        {
                fprintf(stderr, "expand_queue: out of memory\n");
                abort();
        }
        nq->cap = QUEUESIZE + q->cap;
        nq->head = 0;
        nq->tail = q->cap;
        memcpy(nq->hash, q->hash, sizeof(nq->hash));
        memset(q->hash, 0, sizeof(q->hash));

        /* A plain memcpy of q->queue would keep the wrap-around order. The
         * oldest packet is at q->head, which need not be index 0, so copy
         * starting from there to make the new ring contiguous from 0. */
        int i;
        for (i = 0; i < q->cap; i++)
        {
                int idx = (q->head + i) % q->cap;
                nq->queue[i] = q->queue[idx];
        }
        q->head = q->tail = 0;
        Q = nq;
}

/*
 * Append one complete packet (`size` bytes for socket `fd`) to q's ring.
 *
 * clone != 0: the bytes at `buffer` are copied into a fresh skynet_malloc
 * block. clone == 0: `buffer` itself is stored; the only such caller passes
 * an already heap-allocated packet body. If the push fills the ring,
 * expand_queue() builds a larger queue and repoints the global Q, which
 * leaves q stale.
 */
static inline void
push_data(struct queue *q, int fd, uint8_t *buffer, int size, int clone)
{
        void *payload = buffer;
        if (clone)
        {
                payload = (void *)skynet_malloc(size);
                memcpy(payload, buffer, size);
        }

        struct netpack *slot = &q->queue[q->tail];
        slot->id = fd;
        slot->size = size;
        slot->buffer = payload;

        q->tail = (q->tail + 1 >= q->cap) ? q->tail + 1 - q->cap : q->tail + 1;
        if (q->tail == q->head)
        {
                expand_queue(q);
        }
}

static inline void
push_more(struct queue *q, uint8_t *buffer, int size, int fd)
{
    if (size == 1)
    {
            struct uncomplete *uc = save_uncomplete(q, fd);
            uc->header = *buffer;
            uc->read = -1;
            return;
    }
    int pack_size = read_size(buffer);
    buffer += 2;
    size -= 2;
    if (size < pack_size)
    {
            struct uncomplete *uc = save_uncomplete(q, fd);
            uc->pack.size = pack_size;
            uc->read = size;
            uc->pack.buffer = (void *)skynet_malloc(pack_size);
            memcpy(uc->pack.buffer, buffer, size);
            return;
    }
    push_data(q, fd, buffer, pack_size, 1);

    buffer += pack_size;
    size -= pack_size;
    if (size > 0)
    {
            push_more(q, buffer, size, fd);
    }
}

/*static void
close_uncomplete(struct queue *q, int fd)
{
    struct uncomplete *uc = find_uncomplete(q, fd);
    if (uc)
    {
            skynet_free(uc->pack.buffer);
            skynet_free(uc);
    }
}*/

int
filter_data_(uint8_t *buffer, int size, int fd, struct my_data *md)
{
        printf("size[%d], fd[%d]
", size, fd);
        struct queue *q = Q;
        struct uncomplete *uc = find_uncomplete(q, fd);
        if (uc)
        {                
                if (uc->read < 0)
                {        
                        assert(uc->read == -1);
                        
                        int pack_size = *buffer;
                        pack_size |= uc->header << 8;
                        ++buffer;
                        --size;
                        uc->pack.size = pack_size;
                        uc->pack.buffer = (void *)skynet_malloc(pack_size);
                        uc->read = 0;
                }        
                int need = uc->pack.size - uc->read;
                printf("need[%d];pack.size[%d];uc->read[%d]
", need , uc->pack.size ,uc->read);
                if (size < need)
                {        
                        memcpy(uc->pack.buffer + uc->read, buffer, size);    
                        uc->read += size;
                        int h = hash_fd(fd);
                        uc->next = q->hash[h];
                        q->hash[h] = uc;
                        return 1;
                }        
                memcpy(uc->pack.buffer + uc->read, buffer, need);
                buffer += need;
                size -= need;
                if (size == 0)
                {
                        //TODO
                        md->size = uc->pack.size;
                        md->fd = fd;
                        md->data = (void *)skynet_malloc(md->size);
                        memcpy(md->data, uc->pack.buffer, uc->pack.size);
                        skynet_free(uc);
                        return 0;
                }
                push_data(q, fd, uc->pack.buffer, uc->pack.size, 0);
                skynet_free(uc);
                push_more(q, buffer, size, fd);
                return 2;
        }
        else
        {
                if(size == 1)
                {
                        struct uncomplete *uc = save_uncomplete(q, fd);
                        uc->read = -1;
                        uc->header = *buffer;
                        return 1;
                }
                int pack_size = read_size(buffer);
                buffer += 2;
                size -= 2;
                //printf("pack_size[%d], size[%d], buffer[%s]
", pack_size, size, buffer);
                if (size < pack_size)
                {
                        printf("size < pack_size
");
                        struct uncomplete *uc = save_uncomplete(q, fd);
                        uc->read = size;
                        uc->pack.size = pack_size;
                        uc->pack.buffer = (void *)skynet_malloc(pack_size);
                        memcpy(uc->pack.buffer, buffer, size);
                        return 1;
                }

                if(size == pack_size)
                {
                        printf("size == packsize
");
                        //TODO
                        md->size = size;
                        md->fd = fd;
                        md->data = (void *)skynet_malloc(size);
                        memcpy(md->data, buffer, size);
                        return 0;
                }            

                push_data(q, fd, buffer, pack_size, 1);
                buffer += pack_size;
                size -= pack_size;    
                push_more(q, buffer, size, fd);
                return 2;
        }
}
View Code

message_resolve.h

#ifndef MESSAGE_RESOLVE_H
#define MESSAGE_RESOLVE_H
#include <stdint.h>

/* One complete packet handed back by filter_data_(). */
struct my_data
{
    int size;     /* body length in bytes */
    int fd;       /* socket the packet arrived on */
    void *data;   /* heap-allocated body, without the 2-byte length prefix */
};

/* Allocate the global unpack queue; call once before filter_data_(). */
void
init_queue(void);

/*
 * Feed `size` bytes read from socket `fd` into the unpacker.
 * Returns 0 if exactly one complete packet was produced (stored in *md),
 * 1 if the bytes were buffered as an incomplete packet, and 2 if one or
 * more packets were completed and queued internally.
 */
int
filter_data_(uint8_t *buffer, int size, int fd, struct my_data *md);

#endif
View Code

skynet_malloc.h

#ifndef skynet_malloc_h
#define skynet_malloc_h

#include <stddef.h>

/*
 * Stand-alone replacement for skynet's allocator header: each skynet_*
 * allocation name is mapped straight onto the C library function.
 *
 * Because the macros are defined first, the prototypes below are expanded
 * by the preprocessor, so they are (compatible) redeclarations of
 * malloc/calloc/realloc/free rather than declarations of skynet_* symbols.
 */
#define skynet_malloc malloc
#define skynet_calloc calloc
#define skynet_realloc realloc
#define skynet_free free

void * skynet_malloc(size_t sz);
void * skynet_calloc(size_t nmemb,size_t size);
void * skynet_realloc(void *ptr, size_t size);
void skynet_free(void *ptr);
/* NOTE(review): no definition of skynet_strdup or skynet_lalloc appears in
 * the files shown; a call to either would presumably fail to link. */
char * skynet_strdup(const char *str);
void * skynet_lalloc(void *ptr, size_t osize, size_t nsize);    // use for lua

#endif
View Code

socket-client.c

#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <unistd.h>

#define IP "127.0.0.1"
#define PORT "8888"
/* Send-buffer payload capacity; 2 extra bytes hold the length prefix.
 * Parenthesized so the macro behaves as a single value in any expression. */
#define BUFF_SZ (1024 * 1024)
/* Largest generated message body. The wire length prefix is 16 bits wide,
 * so this must stay within [1, 65535]. */
#define MSG_LEN 65535
#define SEND_TIMES 1000

#if MSG_LEN < 1 || MSG_LEN > 65535
#error "MSG_LEN must be in [1, 65535]: the length prefix is 2 bytes"
#endif
#if MSG_LEN > BUFF_SZ
#error "MSG_LEN must not exceed BUFF_SZ"
#endif

/* Add O_NONBLOCK to fd's status flags, keeping the others.
 * Does nothing if the current flags cannot be read. */
static void
sp_nonblocking(int fd) 
{
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags != -1)
    {
        fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    }
}

/* Turn on SO_KEEPALIVE for fd; a failure is silently ignored. */
static void
socket_keepalive(int fd)
{
    const int on = 1;
    setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof on);
}

/*
 * Open a blocking TCP connection to IP:PORT and return the socket.
 *
 * Every address returned by getaddrinfo() is tried in turn. Returns -1 if
 * name resolution fails or no address accepts the connection.
 * If connect() reports EINPROGRESS, the process exits, as before. That only
 * happens if the socket is made non-blocking (see the commented-out
 * sp_nonblocking call). The address list is freed on every path, and is
 * never passed to freeaddrinfo() when getaddrinfo() did not allocate it.
 */
static int
open_socket() 
{
    struct addrinfo ai_hints;
    struct addrinfo *ai_list = NULL;
    struct addrinfo *ai_ptr = NULL;
    memset(&ai_hints, 0, sizeof(ai_hints));
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_STREAM;
    ai_hints.ai_protocol = IPPROTO_TCP;

    int status = getaddrinfo(IP, PORT, &ai_hints, &ai_list);
    if (status != 0)
    {
        fprintf(stderr, "getaddrinfo failed: %s\n", gai_strerror(status));
        return -1;   /* ai_list was not allocated; nothing to free */
    }

    int sock = -1;
    for (ai_ptr = ai_list; ai_ptr != NULL; ai_ptr = ai_ptr->ai_next)
    {
        sock = socket(ai_ptr->ai_family, ai_ptr->ai_socktype, ai_ptr->ai_protocol);
        if (sock < 0)
        {
            continue;
        }
        socket_keepalive(sock);
        //sp_nonblocking(sock);
        status = connect(sock, ai_ptr->ai_addr, ai_ptr->ai_addrlen);
        if (status != 0 && errno != EINPROGRESS)
        {
            fprintf(stderr, "status not 0[%d]\n", status);
            close(sock);
            sock = -1;
            continue;
        }
        break;
    }
    freeaddrinfo(ai_list);

    if (sock < 0)
    {
        return -1;
    }
    if (status != 0)
    {
        /* Connection still in progress: this client only supports blocking sockets. */
        fprintf(stderr, "connect failed\n");
        close(sock);
        exit(EXIT_FAILURE);
    }
    return sock;
}

/*
 * Write the 2-byte big-endian length prefix for a `len`-byte body into
 * msg[0] and msg[1]. The body itself is expected at msg[2 .. len+1].
 *
 * The two bytes are stored explicitly. The old code memcpy'd the first two
 * bytes of a byte-swapped int, which is only big-endian on little-endian
 * hosts. Exits with EXIT_FAILURE if msg is NULL or len does not fit in the
 * 16-bit length field.
 */
static void
pack(char *msg, int len)
{
    if (msg == NULL || len < 0 || len > 0xFFFF)
    {
        fprintf(stderr, "wrong param\n");
        exit(EXIT_FAILURE);
    }
    printf("in pack len is [%d]\n", len);
    unsigned char *hdr = (unsigned char *)msg;
    hdr[0] = (unsigned char)((len >> 8) & 0xFF);
    hdr[1] = (unsigned char)(len & 0xFF);
}

/*
 * Generate a random message body of 'a' characters at newmsg[2 .. len+1]
 * and return its length len, chosen with rand() in [1, MSG_LEN].
 * newmsg[0..1] are left for pack() to fill with the length prefix, so the
 * buffer must hold at least MSG_LEN + 2 bytes.
 *
 * The length used to be rand() % (MSG_LEN - 1) + 2. That divides by zero
 * when MSG_LEN is set to 1, and it never produced a 1-byte message.
 */
static int
gen_msg(char *newmsg)
{
    if (newmsg == NULL)
    {
        fprintf(stderr, "newmsg == nil\n");
        exit(EXIT_FAILURE);
    }
    int msglen = rand() % MSG_LEN + 1;
    printf("msglen is: %d\n", msglen);

    memset(newmsg + 2, 'a', (size_t)msglen);  /* first two bytes reserved for len */
    return msglen;
}

/*
 * Send exactly `size` bytes from msg over fd. Partial sends are retried
 * until everything is written, and EINTR is retried.
 * Returns the number of bytes sent (== size), or 0 if msg is NULL or
 * send() fails with any other error.
 */
static int
send_msg(int fd, char *msg, int size)
{
    if (msg == NULL)
    {
        fprintf(stderr, "msg == NULL\n");
        return 0;
    }
    printf("in send_msg...........\n");
    int sent = 0;
    while (sent < size)
    {
        ssize_t n = send(fd, msg + sent, (size_t)(size - sent), 0);
        if (n < 0)
        {
            if (errno == EINTR)
            {
                continue;
            }
            fprintf(stderr, "send failed error[%s]\n", strerror(errno));
            return 0;
        }
        sent += (int)n;
    }
    assert(sent == size);
    printf("send over in sendmsg size[%d]\n", size);
    return sent;
}


/*
 * Interactive loop: read a message count from stdin, then generate, frame
 * and send that many random messages over fd. Stops at EOF, on
 * non-numeric input, or as soon as a send fails.
 *
 * fflush(stdin) was removed because it is undefined behaviour in ISO C.
 * The per-message memset of the whole buffer was also dropped: gen_msg()
 * and pack() overwrite every byte that is sent.
 */
static void
do_sth(int fd)
{
    /* static: a 1 MiB buffer is too large for the stack */
    static char send_buff[BUFF_SZ + 2];
    int num = 0;
    while (printf("enter msg num:\n"), fflush(stdout), scanf("%d", &num) == 1)
    {
        int i;
        for (i = 0; i < num; i++)
        {
            int len = gen_msg(send_buff);
            pack(send_buff, len);
            if (send_msg(fd, send_buff, len + 2) != len + 2)
            {
                return;
            }
        }
    }
}


/* Seed the message generator, connect to the server, run the send loop. */
int
main(void)
{
    srand((unsigned int)time(NULL));
    int sock = open_socket();
    if (sock < 0)
    {
        fprintf(stderr, "open_socket failed\n");
        exit(EXIT_FAILURE);
    }
    do_sth(sock);
    printf("close is called\n");
    close(sock);
    return 0;
}
View Code

socket_epoll.h

#ifndef poll_socket_epoll_h
#define poll_socket_epoll_h

#include <netdb.h>
#include <unistd.h>
#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>

/* Linux epoll backend for the socket_poll.h interface. */

/* True when efd is the failure value (-1) from sp_create(). */
static bool 
sp_invalid(int efd) {
    return -1 == efd;
}

/* Create the epoll instance. The size hint must be positive; its value
 * is ignored by Linux since 2.6.8. */
static int
sp_create() {
    const int size_hint = 1024;
    return epoll_create(size_hint);
}

static void
sp_release(int efd) {
    close(efd);
}

/* Watch sock for readability, tagging its events with ud.
 * Returns 0 on success, 1 on failure. */
static int 
sp_add(int efd, int sock, void *ud) {
    struct epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.ptr = ud;
    return epoll_ctl(efd, EPOLL_CTL_ADD, sock, &ev) == -1 ? 1 : 0;
}

/* Stop watching sock; errors are ignored. */
static void 
sp_del(int efd, int sock) {
    epoll_ctl(efd, EPOLL_CTL_DEL, sock , NULL);
}

/* Switch write-readiness reporting on or off; readability stays on. */
static void 
sp_write(int efd, int sock, void *ud, bool enable) {
    struct epoll_event ev;
    ev.events = enable ? (EPOLLIN | EPOLLOUT) : EPOLLIN;
    ev.data.ptr = ud;
    epoll_ctl(efd, EPOLL_CTL_MOD, sock, &ev);
}

/* Block until events arrive, translating up to max of them into e[].
 * Returns epoll_wait()'s result (-1 on error, e untouched). */
static int 
sp_wait(int efd, struct event *e, int max) {
    struct epoll_event ev[max];
    int n = epoll_wait(efd , ev, max, -1);
    int i;
    for (i = 0; i < n; i++) {
        const unsigned flag = ev[i].events;
        e[i].s = ev[i].data.ptr;
        e[i].read = (flag & EPOLLIN) != 0;
        e[i].write = (flag & EPOLLOUT) != 0;
    }
    return n;
}

/* Add O_NONBLOCK to fd, keeping its other flags; no-op if they can't be read. */
static void
sp_nonblocking(int fd) {
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags != -1) {
        fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    }
}

#endif
View Code

socket_kqueue.h

#ifndef poll_socket_kqueue_h
#define poll_socket_kqueue_h

#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/event.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/* BSD / macOS kqueue backend for the socket_poll.h interface. */

/* True when kfd is the failure value (-1) from sp_create(). */
static bool 
sp_invalid(int kfd) {
    return -1 == kfd;
}

static int
sp_create() {
    return kqueue();
}

static void
sp_release(int kfd) {
    close(kfd);
}

/* Submit a single filter change for sock to kfd; returns kevent()'s result. */
static int
kq_change(int kfd, int sock, short filter, unsigned short flags, void *ud) {
    struct kevent ke;
    EV_SET(&ke, sock, filter, flags, 0, 0, ud);
    return kevent(kfd, &ke, 1, NULL, 0, NULL);
}

/* Remove both the read and the write filter for sock; errors are ignored. */
static void 
sp_del(int kfd, int sock) {
    kq_change(kfd, sock, EVFILT_READ, EV_DELETE, NULL);
    kq_change(kfd, sock, EVFILT_WRITE, EV_DELETE, NULL);
}

/* Register read and write filters for sock (write starts disabled), tagged
 * with ud. On any failure the filters added so far are removed and 1 is
 * returned; 0 on success. */
static int 
sp_add(int kfd, int sock, void *ud) {
    if (kq_change(kfd, sock, EVFILT_READ, EV_ADD, ud) == -1) {
        return 1;
    }
    if (kq_change(kfd, sock, EVFILT_WRITE, EV_ADD, ud) == -1) {
        kq_change(kfd, sock, EVFILT_READ, EV_DELETE, NULL);
        return 1;
    }
    if (kq_change(kfd, sock, EVFILT_WRITE, EV_DISABLE, ud) == -1) {
        sp_del(kfd, sock);
        return 1;
    }
    return 0;
}

/* Enable or disable the write filter for sock. */
static void 
sp_write(int kfd, int sock, void *ud, bool enable) {
    if (kq_change(kfd, sock, EVFILT_WRITE, enable ? EV_ENABLE : EV_DISABLE, ud) == -1) {
        // todo: check error
    }
}

/* Block until events arrive, translating up to max of them into e[].
 * Returns kevent()'s result (-1 on error). */
static int 
sp_wait(int kfd, struct event *e, int max) {
    struct kevent ev[max];
    int n = kevent(kfd, NULL, 0, ev, max, NULL);
    int i;
    for (i = 0; i < n; i++) {
        const unsigned filter = ev[i].filter;
        e[i].s = ev[i].udata;
        e[i].read = (filter == EVFILT_READ);
        e[i].write = (filter == EVFILT_WRITE);
    }
    return n;
}

/* Add O_NONBLOCK to fd, keeping its other flags; no-op if they can't be read. */
static void
sp_nonblocking(int fd) {
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags != -1) {
        fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    }
}

#endif
View Code

socket_poll.h

#ifndef socket_poll_h
#define socket_poll_h

#include <stdbool.h>

/* Platform-independent event-poll interface used by socket_server.c.
 * Exactly one backend header below supplies the definitions. */

typedef int poll_fd;

/* One readiness notification produced by sp_wait(). */
struct event {
    void * s;      /* user data registered with sp_add()/sp_write() */
    bool read;     /* socket is readable */
    bool write;    /* socket is writable */
};

static bool sp_invalid(poll_fd fd);
static poll_fd sp_create(void);
static void sp_release(poll_fd fd);
static int sp_add(poll_fd fd, int sock, void *ud);
static void sp_del(poll_fd fd, int sock);
static void sp_write(poll_fd, int sock, void *ud, bool enable);
static int sp_wait(poll_fd, struct event *e, int max);
static void sp_nonblocking(int sock);

#if defined(__linux__)
#include "socket_epoll.h"
#elif defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined (__NetBSD__)
#include "socket_kqueue.h"
#else
#error "socket_poll.h: no event-poll backend for this platform (needs epoll or kqueue)"
#endif

#endif
View Code

socket_server.c

#ifdef SOCKET_SERVER_FILE_MEMAPI
#   define  STRINIFY_(S)    #S
#   define  STRINIFY(S)     STRINIFY_(S)
#   include STRINIFY(SOCKET_SERVER_FILE_MEMAPI)
#   undef   STRINIFY
#   undef   STRINIFY_
#endif

#include "socket_server.h"
#include "socket_poll.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <unistd.h>
#include <errno.h>
#include <stddef.h>
#include <stdlib.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdint.h>
#include <assert.h>
#include <string.h>

#define MAX_INFO 128
// MAX_SOCKET will be 2^MAX_SOCKET_P
#define MAX_SOCKET_P 16
#define MAX_EVENT 64
#define MIN_READ_BUFFER 64
/* Slot states stored in struct socket::type. */
#define SOCKET_TYPE_INVALID 0
#define SOCKET_TYPE_RESERVE 1
#define SOCKET_TYPE_PLISTEN 2
#define SOCKET_TYPE_LISTEN 3
#define SOCKET_TYPE_CONNECTING 4
#define SOCKET_TYPE_CONNECTED 5
#define SOCKET_TYPE_HALFCLOSE 6
#define SOCKET_TYPE_PACCEPT 7
#define SOCKET_TYPE_BIND 8

#define MAX_SOCKET (1<<MAX_SOCKET_P)

#define PRIORITY_HIGH 0
#define PRIORITY_LOW 1

/* Slot index for a socket id. The argument is parenthesized so that
 * expressions such as HASH_ID(a + b) are converted and reduced as a whole. */
#define HASH_ID(id) (((unsigned)(id)) % MAX_SOCKET)

#define PROTOCOL_TCP 0
#define PROTOCOL_UDP 1
#define PROTOCOL_UDPv6 2

#define UDP_ADDRESS_SIZE 19    // ipv6 128bit + port 16bit + 1 byte type

#define MAX_UDP_PACKAGE 65535

/* One pending outgoing buffer in a socket's singly linked write list. */
struct write_buffer {
    struct write_buffer * next;
    void *buffer;        /* the allocation released by write_buffer_free() */
    char *ptr;           /* next byte to send; advanced after a partial TCP write */
    int sz;              /* bytes remaining from ptr */
    bool userobject;     /* buffer is a user object released through soi.free */
    uint8_t udp_address[UDP_ADDRESS_SIZE];  /* destination, used for UDP buffers */
};

/* SIZEOF_TCPBUFFER stops before udp_address, presumably so TCP buffers can be
 * allocated without that trailing field; udp_address must remain the last member. */
#define SIZEOF_TCPBUFFER (offsetof(struct write_buffer, udp_address[0]))
#define SIZEOF_UDPBUFFER (sizeof(struct write_buffer))

/* Head/tail of a write_buffer list (both NULL when empty). */
struct wb_list {
    struct write_buffer * head;
    struct write_buffer * tail;
};

/* One socket slot; slot index is HASH_ID(id). */
struct socket {
    uintptr_t opaque;    /* caller's handle, echoed in socket_message results */
    struct wb_list high; /* high-priority pending writes */
    struct wb_list low;  /* low-priority pending writes */
    int64_t wb_size;     /* bytes still queued for writing */
    int fd;
    int id;
    uint16_t protocol;   /* PROTOCOL_TCP / PROTOCOL_UDP / PROTOCOL_UDPv6 */
    uint16_t type;       /* SOCKET_TYPE_* state */
    union {
        int size;        /* TCP: initialised to MIN_READ_BUFFER; presumably the next read size */
        uint8_t udp_address[UDP_ADDRESS_SIZE];  /* UDP: presumably the default peer — TODO confirm */
    } p;
};

struct socket_server {
    int recvctrl_fd;     /* read end of the control pipe, registered with the poll */
    int sendctrl_fd;     /* write end of the control pipe */
    int checkctrl;       /* set to 1 at creation; its use is not visible here */
    poll_fd event_fd;    /* epoll / kqueue handle */
    int alloc_id;        /* id counter advanced by reserve_id() */
    int event_n;         /* presumably the event count from the last sp_wait() */
    int event_index;     /* presumably the next ev[] entry to process */
    struct socket_object_interface soi;  /* callbacks for user-object buffers */
    struct event ev[MAX_EVENT];
    struct socket slot[MAX_SOCKET];
    char buffer[MAX_INFO];  /* scratch text, e.g. the peer address from inet_ntop() */
    uint8_t udpbuffer[MAX_UDP_PACKAGE];
    fd_set rfds;         /* cleared at creation; its use is not visible here */
};

/*
 * Payloads of control requests, one per command letter listed below.
 * The `char host[1]` members are variable length: presumably the request is
 * built in the larger request_package buffer so host holds a complete
 * NUL-terminated string — TODO confirm against the writer.
 */
struct request_open {
    int id;
    int port;
    uintptr_t opaque;
    char host[1];
};

struct request_send {
    int id;
    int sz;              /* negative sz presumably marks a user object (see send_object_init) */
    char * buffer;
};

struct request_send_udp {
    struct request_send send;
    uint8_t address[UDP_ADDRESS_SIZE];
};

struct request_setudp {
    int id;
    uint8_t address[UDP_ADDRESS_SIZE];
};

struct request_close {
    int id;
    uintptr_t opaque;
};

struct request_listen {
    int id;
    int fd;
    uintptr_t opaque;
    char host[1];
};

struct request_bind {
    int id;
    int fd;
    uintptr_t opaque;
};

struct request_start {
    int id;
    uintptr_t opaque;
};

struct request_setopt {
    int id;
    int what;
    int value;
};

struct request_udp {
    int id;
    int fd;
    int family;
    uintptr_t opaque;
};

/*
    The first byte is TYPE

    S Start socket
    B Bind socket
    L Listen socket
    K Close socket
    O Connect to (Open)
    X Exit
    D Send package (high)
    P Send package (low)
    A Send UDP package
    T Set opt
    U Create UDP socket
    C set udp address
 */

/* Framing for one control request. Per its comment, only the last 2 header
 * bytes are meaningful (presumably the type letter and payload length —
 * TODO confirm). The union and trailing dummy leave room for host strings. */
struct request_package {
    uint8_t header[8];    // 6 bytes dummy
    union {
        char buffer[256];
        struct request_open open;
        struct request_send send;
        struct request_send_udp send_udp;
        struct request_close close;
        struct request_listen listen;
        struct request_bind bind;
        struct request_start start;
        struct request_setopt setopt;
        struct request_udp udp;
        struct request_setudp set_udp;
    } u;
    uint8_t dummy[256];
};

/* Storage large enough for either an IPv4 or an IPv6 socket address. */
union sockaddr_all {
    struct sockaddr s;
    struct sockaddr_in v4;
    struct sockaddr_in6 v6;
};

/* A payload to send: data pointer, byte count and the function that frees it
 * (filled in by send_object_init()). */
struct send_object {
    void * buffer;
    int sz;
    void (*free_func)(void *);
};

/* Allocator hooks: a build may define SOCKET_SERVER_MALLOC and/or
 * SOCKET_SERVER_FREE to replace the C library malloc/free. */
#ifndef SOCKET_SERVER_MALLOC
#    define MALLOC     malloc
#else
#    define MALLOC     SOCKET_SERVER_MALLOC
#endif

#ifndef SOCKET_SERVER_FREE
#    define FREE     free
#else
#    define FREE     SOCKET_SERVER_FREE
#endif

/*
 * Describe a payload to send.
 * sz < 0: `object` is a user object. Its data pointer, size and free
 * function come from ss->soi, and true is returned.
 * Otherwise `object` is a plain buffer of sz bytes released with FREE, and
 * false is returned.
 */
static inline bool
send_object_init(struct socket_server *ss, struct send_object *so, void *object, int sz) {
    if (sz >= 0) {
        so->buffer = object;
        so->sz = sz;
        so->free_func = FREE;
        return false;
    }
    so->buffer = ss->soi.buffer(object);
    so->sz = ss->soi.size(object);
    so->free_func = ss->soi.free;
    return true;
}

static inline void
write_buffer_free(struct socket_server *ss, struct write_buffer *wb) {
    if (wb->userobject) {
        ss->soi.free(wb->buffer);
    } else {
        FREE(wb->buffer);
    }
    FREE(wb);
}

/* Turn on SO_KEEPALIVE for fd; a failure is silently ignored. */
static void
socket_keepalive(int fd) {
    const int on = 1;
    setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof on);
}

/*
 * Claim a free socket slot and return its new id, or -1 if none is found.
 *
 * Ids come from atomically incrementing ss->alloc_id; if the counter goes
 * negative it is masked back to a non-negative value. The candidate slot
 * HASH_ID(id) is claimed by atomically switching its type from INVALID to
 * RESERVE, so two callers cannot both claim it. Up to MAX_SOCKET candidates
 * are tried; losing the compare-and-swap race does not count toward that
 * limit (--i).
 */
static int
reserve_id(struct socket_server *ss) {
    int i;
    for (i=0;i<MAX_SOCKET;i++) {
        int id = __sync_add_and_fetch(&(ss->alloc_id), 1);
        if (id < 0) {
            // counter wrapped: clear the sign bit
            id = __sync_and_and_fetch(&(ss->alloc_id), 0x7fffffff);
        }
        struct socket *s = &ss->slot[HASH_ID(id)];
        if (s->type == SOCKET_TYPE_INVALID) {
            if (__sync_bool_compare_and_swap(&s->type, SOCKET_TYPE_INVALID, SOCKET_TYPE_RESERVE)) {
                s->id = id;
                s->fd = -1;
                return id;
            } else {
                // retry
                --i;
            }
        }
    }
    return -1;
}

/* Reset a write-buffer list to empty without freeing any nodes. */
static inline void
clear_wb_list(struct wb_list *list) {
    list->head = list->tail = NULL;
}

/*
 * Create a socket server: an event poll (epoll/kqueue) plus a pipe used as
 * the control channel. The read end (recvctrl_fd) is registered with the
 * poll with NULL user data; the write end is sendctrl_fd. Every one of the
 * MAX_SOCKET slots starts INVALID with empty write lists.
 *
 * Returns NULL, after releasing everything acquired so far, if any step
 * fails (including running out of memory for the server struct).
 */
struct socket_server *
socket_server_create() {
    int i;
    int fd[2];
    poll_fd efd = sp_create();
    if (sp_invalid(efd)) {
        fprintf(stderr, "socket-server: create event pool failed.\n");
        return NULL;
    }
    if (pipe(fd)) {
        sp_release(efd);
        /* (the control channel is a pipe, despite the wording) */
        fprintf(stderr, "socket-server: create socket pair failed.\n");
        return NULL;
    }
    if (sp_add(efd, fd[0], NULL)) {
        // add recvctrl_fd to event poll
        fprintf(stderr, "socket-server: can't add server fd to event pool.\n");
        close(fd[0]);
        close(fd[1]);
        sp_release(efd);
        return NULL;
    }

    struct socket_server *ss = MALLOC(sizeof(*ss));
    if (ss == NULL) {
        fprintf(stderr, "socket-server: out of memory.\n");
        close(fd[0]);
        close(fd[1]);
        sp_release(efd);
        return NULL;
    }
    ss->event_fd = efd;
    ss->recvctrl_fd = fd[0];
    ss->sendctrl_fd = fd[1];
    ss->checkctrl = 1;

    for (i=0;i<MAX_SOCKET;i++) {
        struct socket *s = &ss->slot[i];
        s->type = SOCKET_TYPE_INVALID;
        clear_wb_list(&s->high);
        clear_wb_list(&s->low);
    }
    ss->alloc_id = 0;
    ss->event_n = 0;
    ss->event_index = 0;
    memset(&ss->soi, 0, sizeof(ss->soi));
    FD_ZERO(&ss->rfds);
    assert(ss->recvctrl_fd < FD_SETSIZE);

    return ss;
}

/* Free every write buffer in list (payload and node), leaving it empty. */
static void
free_wb_list(struct socket_server *ss, struct wb_list *list) {
    struct write_buffer *next;
    for (struct write_buffer *wb = list->head; wb != NULL; wb = next) {
        next = wb->next;
        write_buffer_free(ss, wb);
    }
    list->head = NULL;
    list->tail = NULL;
}

/*
 * Close socket s immediately and describe it in *result (id and opaque of
 * s; ud = 0, data = NULL). An INVALID slot is only described.
 * Otherwise:
 *   - both pending write lists are freed;
 *   - the fd is removed from the poll, except for PACCEPT/PLISTEN sockets
 *     (presumably not registered yet);
 *   - the fd is closed, except for BIND sockets (presumably not owned by
 *     the server — TODO confirm);
 *   - the slot becomes INVALID.
 * Must not be called on a RESERVE slot.
 */
static void
force_close(struct socket_server *ss, struct socket *s, struct socket_message *result) {
    result->id = s->id;
    result->ud = 0;
    result->data = NULL;
    result->opaque = s->opaque;

    const int type = s->type;
    if (type == SOCKET_TYPE_INVALID) {
        return;
    }
    assert(s->type != SOCKET_TYPE_RESERVE);

    free_wb_list(ss, &s->high);
    free_wb_list(ss, &s->low);

    const bool registered = !(type == SOCKET_TYPE_PACCEPT || type == SOCKET_TYPE_PLISTEN);
    if (registered) {
        sp_del(ss->event_fd, s->fd);
    }
    if (type != SOCKET_TYPE_BIND) {
        close(s->fd);
    }
    s->type = SOCKET_TYPE_INVALID;
}

void
socket_server_release(struct socket_server *ss) {
    int i;
    struct socket_message dummy;
    for (i=0;i<MAX_SOCKET;i++) {
        struct socket *s = &ss->slot[i];
        if (s->type != SOCKET_TYPE_RESERVE) {
            force_close(ss, s , &dummy);
        }
    }
    close(ss->sendctrl_fd);
    close(ss->recvctrl_fd);
    sp_release(ss->event_fd);
    FREE(ss);
}

/* Debug check: assert that the write list is empty (both ends NULL). */
static inline void
check_wb_list(struct wb_list *s) {
    assert(s->head == NULL);
    assert(s->tail == NULL);
}

/*
 * Initialise the RESERVEd slot for `id` to wrap fd.
 * If `add` is true, fd is first registered with the poll using the slot as
 * user data. If that fails, the slot goes back to INVALID and NULL is
 * returned. The caller sets the final socket type. Both write lists must
 * already be empty.
 */
static struct socket *
new_fd(struct socket_server *ss, int id, int fd, int protocol, uintptr_t opaque, bool add) {
    struct socket * s = &ss->slot[HASH_ID(id)];
    assert(s->type == SOCKET_TYPE_RESERVE);

    if (add && sp_add(ss->event_fd, fd, s)) {
        s->type = SOCKET_TYPE_INVALID;
        return NULL;
    }

    s->id = id;
    s->fd = fd;
    s->protocol = protocol;
    s->opaque = opaque;
    s->p.size = MIN_READ_BUFFER;
    s->wb_size = 0;
    check_wb_list(&s->high);
    check_wb_list(&s->low);
    return s;
}

/*
 * Start a TCP connection for request->id to request->host:request->port.
 * The slot for request->id must already be RESERVEd. Each address from
 * getaddrinfo() is tried with a non-blocking connect().
 *
 * Returns:
 *   SOCKET_OPEN  - connected immediately. result->data points at the peer
 *                  address text in ss->buffer when inet_ntop() succeeds.
 *   -1           - connect in progress. The socket is CONNECTING and write
 *                  readiness is enabled on it.
 *   SOCKET_ERROR - resolution or every attempt failed; the slot is set
 *                  back to INVALID.
 *
 * The port is formatted with snprintf (bounded). The address list is only
 * freed when getaddrinfo() actually allocated one.
 */
static int
open_socket(struct socket_server *ss, struct request_open * request, struct socket_message *result) {
    int id = request->id;
    result->opaque = request->opaque;
    result->id = id;
    result->ud = 0;
    result->data = NULL;
    struct socket *ns;
    int status;
    struct addrinfo ai_hints;
    struct addrinfo *ai_list = NULL;
    struct addrinfo *ai_ptr = NULL;
    char port[16];
    snprintf(port, sizeof(port), "%d", request->port);
    memset(&ai_hints, 0, sizeof( ai_hints ) );
    ai_hints.ai_family = AF_UNSPEC;
    ai_hints.ai_socktype = SOCK_STREAM;
    ai_hints.ai_protocol = IPPROTO_TCP;

    status = getaddrinfo( request->host, port, &ai_hints, &ai_list );
    if ( status != 0 ) {
        goto _failed;
    }
    int sock= -1;
    for (ai_ptr = ai_list; ai_ptr != NULL; ai_ptr = ai_ptr->ai_next ) {
        sock = socket( ai_ptr->ai_family, ai_ptr->ai_socktype, ai_ptr->ai_protocol );
        if ( sock < 0 ) {
            continue;
        }
        socket_keepalive(sock);
        sp_nonblocking(sock);
        status = connect( sock, ai_ptr->ai_addr, ai_ptr->ai_addrlen);
        if ( status != 0 && errno != EINPROGRESS) {
            close(sock);
            sock = -1;
            continue;
        }
        break;
    }

    if (sock < 0) {
        goto _failed;
    }

    ns = new_fd(ss, id, sock, PROTOCOL_TCP, request->opaque, true);
    if (ns == NULL) {
        close(sock);
        goto _failed;
    }

    if(status == 0) {
        ns->type = SOCKET_TYPE_CONNECTED;
        struct sockaddr * addr = ai_ptr->ai_addr;
        void * sin_addr = (ai_ptr->ai_family == AF_INET) ? (void*)&((struct sockaddr_in *)addr)->sin_addr : (void*)&((struct sockaddr_in6 *)addr)->sin6_addr;
        if (inet_ntop(ai_ptr->ai_family, sin_addr, ss->buffer, sizeof(ss->buffer))) {
            result->data = ss->buffer;
        }
        freeaddrinfo( ai_list );
        return SOCKET_OPEN;
    } else {
        ns->type = SOCKET_TYPE_CONNECTING;
        sp_write(ss->event_fd, ns->fd, ns, true);
    }

    freeaddrinfo( ai_list );
    return -1;
_failed:
    /* getaddrinfo() leaves ai_list NULL on failure; freeaddrinfo(NULL) is
     * not guaranteed to be safe on every libc. */
    if (ai_list != NULL) {
        freeaddrinfo( ai_list );
    }
    ss->slot[HASH_ID(id)].type = SOCKET_TYPE_INVALID;
    return SOCKET_ERROR;
}

/*
 * Write queued buffers from `list` to TCP socket s, front to back.
 *
 * EINTR retries the write. Each fully written buffer is unlinked and freed;
 * s->wb_size drops by every byte written.
 * A partial write advances the buffer's ptr/sz and stops with -1. EAGAIN
 * also stops with -1, as does draining the whole list (which also resets
 * tail). Any other write error force-closes s, fills *result and returns
 * SOCKET_CLOSE.
 */
static int
send_list_tcp(struct socket_server *ss, struct socket *s, struct wb_list *list, struct socket_message *result) {
    struct write_buffer *tmp;
    while ((tmp = list->head) != NULL) {
        int sz;
        do {
            sz = write(s->fd, tmp->ptr, tmp->sz);
        } while (sz < 0 && errno == EINTR);

        if (sz < 0) {
            if (errno == EAGAIN) {
                return -1;
            }
            force_close(ss, s, result);
            return SOCKET_CLOSE;
        }

        s->wb_size -= sz;
        if (sz != tmp->sz) {
            tmp->ptr += sz;
            tmp->sz -= sz;
            return -1;
        }
        list->head = tmp->next;
        write_buffer_free(ss, tmp);
    }
    list->tail = NULL;
    return -1;
}

/*
 * Decode a packed UDP address into a sockaddr suitable for sendto().
 *
 * Packed layout: 1 byte protocol tag, 2 port bytes (copied verbatim into
 * sin_port / sin6_port), then 4 (IPv4) or 16 (IPv6) address bytes.
 *
 * Returns the length of the filled sockaddr, or 0 when the packed protocol
 * tag does not match the socket's protocol (or is not a UDP protocol).
 */
static socklen_t
udp_socket_address(struct socket *s, const uint8_t udp_address[UDP_ADDRESS_SIZE], union sockaddr_all *sa) {
    const uint8_t *port_bytes = udp_address + 1;
    const uint8_t *addr_bytes = port_bytes + sizeof(uint16_t);
    uint16_t port = 0;

    if ((int)udp_address[0] != s->protocol) {
        return 0;
    }
    memcpy(&port, port_bytes, sizeof(port));

    if (s->protocol == PROTOCOL_UDP) {
        memset(&sa->v4, 0, sizeof(sa->v4));
        sa->s.sa_family = AF_INET;
        sa->v4.sin_port = port;
        memcpy(&sa->v4.sin_addr, addr_bytes, sizeof(sa->v4.sin_addr));     // 32-bit IPv4 address
        return sizeof(sa->v4);
    }
    if (s->protocol == PROTOCOL_UDPv6) {
        memset(&sa->v6, 0, sizeof(sa->v6));
        sa->s.sa_family = AF_INET6;
        sa->v6.sin6_port = port;
        memcpy(&sa->v6.sin6_addr, addr_bytes, sizeof(sa->v6.sin6_addr));   // 128-bit IPv6 address
        return sizeof(sa->v6);
    }
    return 0;
}

static int
send_list_udp(struct socket_server *ss, struct socket *s, struct wb_list *list, struct socket_message *result) {
    while (list->head) {
        struct write_buffer * tmp = list->head;
        union sockaddr_all sa;
        socklen_t sasz = udp_socket_address(s, tmp->udp_address, &sa);
        int err = sendto(s->fd, tmp->ptr, tmp->sz, 0, &sa.s, sasz);
        if (err < 0) {
            switch(errno) {
            case EINTR:
            case EAGAIN:
                return -1;
            }
            fprintf(stderr, "socket-server : udp (%d) sendto error %s.
",s->id, strerror(errno));
            return -1;
/*            // ignore udp sendto error

            result->opaque = s->opaque;
            result->id = s->id;
            result->ud = 0;
            result->data = NULL;

            return SOCKET_ERROR;
*/
        }

        s->wb_size -= tmp->sz;
        list->head = tmp->next;
        write_buffer_free(ss,tmp);
    }
    list->tail = NULL;

    return -1;
}

/* Flush `list` using the TCP or the UDP implementation, chosen by the socket's protocol. */
static int
send_list(struct socket_server *ss, struct socket *s, struct wb_list *list, struct socket_message *result) {
    return (s->protocol == PROTOCOL_TCP)
        ? send_list_tcp(ss, s, list, result)
        : send_list_udp(ss, s, list, result);
}

static inline int
list_uncomplete(struct wb_list *s) {
    struct write_buffer *wb = s->head;
    if (wb == NULL)
        return 0;

    return (void *)wb->ptr != wb->buffer;
}

/*
 * Move the head of the low-priority list onto the high-priority list, which
 * must be empty. This is used so a partially sent low-priority buffer is
 * finished before anything else.
 */
static void
raise_uncomplete(struct socket * s) {
    struct wb_list *low = &s->low;
    struct wb_list *high = &s->high;
    struct write_buffer *moved = low->head;

    // unlink it from the low list
    low->head = moved->next;
    if (low->head == NULL) {
        low->tail = NULL;
    }

    // it becomes the only element of the (empty) high list
    assert(high->head == NULL);
    moved->next = NULL;
    high->tail = moved;
    high->head = moved;
}

/*
    Flush the pending writes of socket `s`. Each socket keeps two write
    lists: high priority and low priority.

    1. Flush the high list as far as the socket accepts.
    2. Only once the high list is drained, flush the low list.
    3. If the low list's head is left partially sent, promote it to the
       (now empty) high list via raise_uncomplete so it is finished first.
    4. If the low list was already empty, disable the write event. A
       half-closed socket is then closed for good.

    Returns SOCKET_CLOSE when the socket got closed (result filled), else -1.
 */
static int
send_buffer(struct socket_server *ss, struct socket *s, struct socket_message *result) {
    assert(!list_uncomplete(&s->low));

    if (send_list(ss, s, &s->high, result) == SOCKET_CLOSE) {
        return SOCKET_CLOSE;
    }
    if (s->high.head != NULL) {
        return -1;    // high list still pending; wait for the next writable event
    }

    if (s->low.head == NULL) {
        // nothing left to write at all
        sp_write(ss->event_fd, s->fd, s, false);
        if (s->type == SOCKET_TYPE_HALFCLOSE) {
            force_close(ss, s, result);
            return SOCKET_CLOSE;
        }
        return -1;
    }

    if (send_list(ss, s, &s->low, result) == SOCKET_CLOSE) {
        return SOCKET_CLOSE;
    }
    if (list_uncomplete(&s->low)) {
        raise_uncomplete(s);
    }
    return -1;
}

static struct write_buffer *
append_sendbuffer_(struct socket_server *ss, struct wb_list *s, struct request_send * request, int size, int n) {
    struct write_buffer * buf = MALLOC(size);
    struct send_object so;
    buf->userobject = send_object_init(ss, &so, request->buffer, request->sz);
    buf->ptr = (char*)so.buffer+n;
    buf->sz = so.sz - n;
    buf->buffer = request->buffer;
    buf->next = NULL;
    if (s->head == NULL) {
        s->head = s->tail = buf;
    } else {
        assert(s->tail != NULL);
        assert(s->tail->next == NULL);
        s->tail->next = buf;
        s->tail = buf;
    }
    return buf;
}

/* Queue a UDP datagram on the list chosen by `priority`, together with its packed destination address. */
static inline void
append_sendbuffer_udp(struct socket_server *ss, struct socket *s, int priority, struct request_send * request, const uint8_t udp_address[UDP_ADDRESS_SIZE]) {
    struct wb_list *target = &s->low;
    if (priority == PRIORITY_HIGH) {
        target = &s->high;
    }
    struct write_buffer *wb = append_sendbuffer_(ss, target, request, SIZEOF_UDPBUFFER, 0);
    memcpy(wb->udp_address, udp_address, UDP_ADDRESS_SIZE);
    s->wb_size += wb->sz;
}

/* Queue the unsent part (starting at byte `n`) of a TCP request on the high-priority list. */
static inline void
append_sendbuffer(struct socket_server *ss, struct socket *s, struct request_send * request, int n) {
    s->wb_size += append_sendbuffer_(ss, &s->high, request, SIZEOF_TCPBUFFER, n)->sz;
}

/* Queue a whole TCP request on the low-priority list. */
static inline void
append_sendbuffer_low(struct socket_server *ss,struct socket *s, struct request_send * request) {
    s->wb_size += append_sendbuffer_(ss, &s->low, request, SIZEOF_TCPBUFFER, 0)->sz;
}

/* Non-zero when neither the high- nor the low-priority write list holds anything. */
static inline int
send_buffer_empty(struct socket *s) {
    return !(s->high.head || s->low.head);
}

/*
    When send a package , we can assign the priority : PRIORITY_HIGH or PRIORITY_LOW

    If socket buffer is empty, write to fd directly.
        If write a part, append the rest part to high list. (Even priority is PRIORITY_LOW)
    Else append package to high (PRIORITY_HIGH) or low (PRIORITY_LOW) list.
 */
static int
send_socket(struct socket_server *ss, struct request_send * request, struct socket_message *result, int priority, const uint8_t *udp_address) {
    int id = request->id;
    struct socket * s = &ss->slot[HASH_ID(id)];
    struct send_object so;
    send_object_init(ss, &so, request->buffer, request->sz);
    if (s->type == SOCKET_TYPE_INVALID || s->id != id
        || s->type == SOCKET_TYPE_HALFCLOSE
        || s->type == SOCKET_TYPE_PACCEPT) {
        so.free_func(request->buffer);
        return -1;
    }
    assert(s->type != SOCKET_TYPE_PLISTEN && s->type != SOCKET_TYPE_LISTEN);
    if (send_buffer_empty(s) && s->type == SOCKET_TYPE_CONNECTED) {
        if (s->protocol == PROTOCOL_TCP) {
            int n = write(s->fd, so.buffer, so.sz);
            if (n<0) {
                switch(errno) {
                case EINTR:
                case EAGAIN:
                    n = 0;
                    break;
                default:
                    fprintf(stderr, "socket-server: write to %d (fd=%d) error :%s.
",id,s->fd,strerror(errno));
                    force_close(ss,s,result);
                    return SOCKET_CLOSE;
                }
            }
            if (n == so.sz) {
                so.free_func(request->buffer);
                return -1;
            }
            append_sendbuffer(ss, s, request, n);    // add to high priority list, even priority == PRIORITY_LOW
        } else {
            // udp
            if (udp_address == NULL) {
                udp_address = s->p.udp_address;
            }
            union sockaddr_all sa;
            socklen_t sasz = udp_socket_address(s, udp_address, &sa);
            int n = sendto(s->fd, so.buffer, so.sz, 0, &sa.s, sasz);
            if (n != so.sz) {
                append_sendbuffer_udp(ss,s,priority,request,udp_address);
            } else {
                so.free_func(request->buffer);
                return -1;
            }
        }
        sp_write(ss->event_fd, s->fd, s, true);
    } else {
        if (s->protocol == PROTOCOL_TCP) {
            if (priority == PRIORITY_LOW) {
                append_sendbuffer_low(ss, s, request);
            } else {
                append_sendbuffer(ss, s, request, 0);
            }
        } else {
            if (udp_address == NULL) {
                udp_address = s->p.udp_address;
            }
            append_sendbuffer_udp(ss,s,priority,request,udp_address);
        }
    }
    return -1;
}

/*
 * Handle an 'L' command: register an already listening fd in slot `id` as
 * SOCKET_TYPE_PLISTEN. start_socket later adds PLISTEN sockets to the poller.
 * On failure the fd is closed, the slot is released and SOCKET_ERROR is
 * reported. Success produces no event (-1).
 */
static int
listen_socket(struct socket_server *ss, struct request_listen * request, struct socket_message *result) {
    const int id = request->id;
    const int fd = request->fd;
    struct socket *s = new_fd(ss, id, fd, PROTOCOL_TCP, request->opaque, false);

    if (s != NULL) {
        s->type = SOCKET_TYPE_PLISTEN;
        return -1;
    }

    close(fd);
    result->opaque = request->opaque;
    result->id = id;
    result->ud = 0;
    result->data = NULL;
    ss->slot[HASH_ID(id)].type = SOCKET_TYPE_INVALID;
    return SOCKET_ERROR;
}

/*
 * Handle a 'K' command: close socket request->id.
 *
 * - An unknown or stale id still reports SOCKET_CLOSE for that id.
 * - Pending writes are flushed first. If everything was sent, the socket is
 *   closed now (SOCKET_CLOSE).
 * - If data is still queued, the socket becomes HALFCLOSE and send_buffer
 *   closes it once the queue drains (no event now, -1).
 */
static int
close_socket(struct socket_server *ss, struct request_close *request, struct socket_message *result) {
    const int id = request->id;
    struct socket *s = &ss->slot[HASH_ID(id)];

    if (s->type == SOCKET_TYPE_INVALID || s->id != id) {
        result->id = id;
        result->opaque = request->opaque;
        result->ud = 0;
        result->data = NULL;
        return SOCKET_CLOSE;
    }

    if (!send_buffer_empty(s)) {
        const int flushed = send_buffer(ss, s, result);
        if (flushed != -1) {
            return flushed;
        }
    }

    if (!send_buffer_empty(s)) {
        // data still pending: finish writing before closing
        s->type = SOCKET_TYPE_HALFCLOSE;
        return -1;
    }

    force_close(ss, s, result);
    result->id = id;
    result->opaque = request->opaque;
    return SOCKET_CLOSE;
}

/*
 * Handle a 'B' command: wrap an existing fd (request->fd) in slot `id` as a
 * SOCKET_TYPE_BIND socket and make it non-blocking.
 * Reports SOCKET_OPEN with data "binding", or SOCKET_ERROR when the slot
 * could not be set up.
 */
static int
bind_socket(struct socket_server *ss, struct request_bind *request, struct socket_message *result) {
    const int id = request->id;
    struct socket *s;

    result->id = id;
    result->opaque = request->opaque;
    result->ud = 0;

    s = new_fd(ss, id, request->fd, PROTOCOL_TCP, request->opaque, true);
    if (s == NULL) {
        result->data = NULL;
        return SOCKET_ERROR;
    }

    sp_nonblocking(request->fd);
    s->type = SOCKET_TYPE_BIND;
    result->data = "binding";
    return SOCKET_OPEN;
}

/*
 * Handle an 'S' command: activate socket request->id.
 *
 * - PACCEPT / PLISTEN sockets are added to the poller. They become CONNECTED
 *   or LISTEN and report SOCKET_OPEN with data "start".
 * - An already CONNECTED socket only takes the new opaque and reports
 *   SOCKET_OPEN with data "transfer".
 * - Any other state produces no event (-1).
 * - An unknown or stale id, or a failing sp_add, reports SOCKET_ERROR.
 */
static int
start_socket(struct socket_server *ss, struct request_start *request, struct socket_message *result) {
    int id = request->id;
    result->id = id;
    result->opaque = request->opaque;
    result->ud = 0;
    result->data = NULL;
    struct socket *s = &ss->slot[HASH_ID(id)];
    if (s->type == SOCKET_TYPE_INVALID || s->id !=id) {
        return SOCKET_ERROR;
    }
    if (s->type == SOCKET_TYPE_PACCEPT || s->type == SOCKET_TYPE_PLISTEN) {
        if (sp_add(ss->event_fd, s->fd, s)) {
            // The fd never made it into the poller and the slot is about to be
            // invalidated, so nothing else will ever close it: close it here
            // instead of leaking it.
            close(s->fd);
            s->type = SOCKET_TYPE_INVALID;
            return SOCKET_ERROR;
        }
        s->type = (s->type == SOCKET_TYPE_PACCEPT) ? SOCKET_TYPE_CONNECTED : SOCKET_TYPE_LISTEN;
        s->opaque = request->opaque;
        result->data = "start";
        return SOCKET_OPEN;
    } else if (s->type == SOCKET_TYPE_CONNECTED) {
        s->opaque = request->opaque;
        result->data = "transfer";
        return SOCKET_OPEN;
    }
    return -1;
}

/*
 * Handle a 'T' command: apply an IPPROTO_TCP-level option (request->what,
 * with int value request->value) to socket request->id. Unknown or stale
 * ids are ignored, and so is the setsockopt result.
 */
static void
setopt_socket(struct socket_server *ss, struct request_setopt *request) {
    const int id = request->id;
    struct socket *s = &ss->slot[HASH_ID(id)];
    if (s->type != SOCKET_TYPE_INVALID && s->id == id) {
        int value = request->value;
        setsockopt(s->fd, IPPROTO_TCP, request->what, &value, sizeof(value));
    }
}

/*
 * Read exactly `sz` bytes of a control command from `pipefd`, retrying on
 * EINTR.
 *
 * A failed read is logged to stderr and the function returns with `buffer`
 * unfilled. A short read trips the assert. POSIX guarantees that pipe writes
 * of at most PIPE_BUF bytes are atomic, so a whole command is expected to
 * arrive in one read.
 */
static void
block_readpipe(int pipefd, void *buffer, int sz) {
    for (;;) {
        int n = read(pipefd, buffer, sz);
        if (n<0) {
            if (errno == EINTR)
                continue;
            fprintf(stderr, "socket-server : read pipe error %s.\n",strerror(errno));
            return;
        }
        // must atomic read from a pipe
        assert(n == sz);
        return;
    }
}

/*
 * Poll the control pipe without blocking, using select() with a zero timeout.
 * Returns 1 when a command is waiting to be read, 0 otherwise (this includes
 * a select() error).
 */
static int
has_cmd(struct socket_server *ss) {
    struct timeval no_wait = {0, 0};
    FD_SET(ss->recvctrl_fd, &ss->rfds);
    return select(ss->recvctrl_fd + 1, &ss->rfds, NULL, NULL, &no_wait) == 1;
}

/*
 * Handle a 'U' command: register the UDP fd in slot udp->id.
 * The protocol is UDPv6 for AF_INET6 and UDP otherwise. The socket is marked
 * CONNECTED with a zeroed (unset) default peer address. If the slot cannot
 * be set up, the fd is closed and the slot released.
 */
static void
add_udp_socket(struct socket_server *ss, struct request_udp *udp) {
    const int id = udp->id;
    const int protocol = (udp->family == AF_INET6) ? PROTOCOL_UDPv6 : PROTOCOL_UDP;
    struct socket *ns = new_fd(ss, id, udp->fd, protocol, udp->opaque, true);

    if (ns != NULL) {
        ns->type = SOCKET_TYPE_CONNECTED;
        memset(ns->p.udp_address, 0, sizeof(ns->p.udp_address));
        return;
    }
    close(udp->fd);
    ss->slot[HASH_ID(id)].type = SOCKET_TYPE_INVALID;
}

/*
 * Handle a 'C' command: store the default destination (packed UDP address)
 * for socket request->id.
 * A protocol mismatch between the address and the socket reports
 * SOCKET_ERROR. Unknown or stale ids and successful updates produce no
 * event (-1).
 */
static int
set_udp_address(struct socket_server *ss, struct request_setudp *request, struct socket_message *result) {
    const int id = request->id;
    struct socket *s = &ss->slot[HASH_ID(id)];
    int type;

    if (s->type == SOCKET_TYPE_INVALID || s->id != id) {
        return -1;
    }

    type = request->address[0];
    if (type == s->protocol) {
        // 1 type byte + 2 port bytes + 4 (IPv4) or 16 (IPv6) address bytes
        const size_t addrsz = (type == PROTOCOL_UDP) ? 1+2+4 : 1+2+16;
        memcpy(s->p.udp_address, request->address, addrsz);
        return -1;
    }

    // protocol mismatch
    result->opaque = s->opaque;
    result->id = s->id;
    result->ud = 0;
    result->data = NULL;
    return SOCKET_ERROR;
}

// Read one command (2-byte header + payload) from the control pipe and run it.
// Returns the SOCKET_* code to report, or -1 when the command produces no event.
static int
ctrl_cmd(struct socket_server *ss, struct socket_message *result) {
    int fd = ss->recvctrl_fd;
    // The payload length is a single byte (header[1]), so it is at most 255
    // and always fits in this 256-byte buffer.
    uint8_t buffer[256];
    uint8_t header[2];
    block_readpipe(fd, header, sizeof(header));
    int type = header[0];
    int len = header[1];
    block_readpipe(fd, buffer, len);
    // Commands only travel over the local pipe, so byte order is not a concern.
    switch (type) {
    case 'S':
        return start_socket(ss,(struct request_start *)buffer, result);
    case 'B':
        return bind_socket(ss,(struct request_bind *)buffer, result);
    case 'L':
        return listen_socket(ss,(struct request_listen *)buffer, result);
    case 'K':
        return close_socket(ss,(struct request_close *)buffer, result);
    case 'O':
        return open_socket(ss, (struct request_open *)buffer, result);
    case 'X':
        result->opaque = 0;
        result->id = 0;
        result->ud = 0;
        result->data = NULL;
        return SOCKET_EXIT;
    case 'D':
        return send_socket(ss, (struct request_send *)buffer, result, PRIORITY_HIGH, NULL);
    case 'P':
        return send_socket(ss, (struct request_send *)buffer, result, PRIORITY_LOW, NULL);
    case 'A': {
        struct request_send_udp * rsu = (struct request_send_udp *)buffer;
        return send_socket(ss, &rsu->send, result, PRIORITY_HIGH, rsu->address);
    }
    case 'C':
        return set_udp_address(ss, (struct request_setudp *)buffer, result);
    case 'T':
        setopt_socket(ss, (struct request_setopt *)buffer);
        return -1;
    case 'U':
        add_udp_socket(ss, (struct request_udp *)buffer);
        return -1;
    default:
        fprintf(stderr, "socket-server: Unknown ctrl %c.\n",type);
        return -1;
    }
}

// return -1 (ignore) when error
static int
forward_message_tcp(struct socket_server *ss, struct socket *s, struct socket_message * result) {
    int sz = s->p.size;
    char * buffer = MALLOC(sz);
    int n = (int)read(s->fd, buffer, sz);
    if (n<0) {
        FREE(buffer);
        switch(errno) {
        case EINTR:
            break;
        case EAGAIN:
            fprintf(stderr, "socket-server: EAGAIN capture.
");
            break;
        default:
            // close when error
            force_close(ss, s, result);
            return SOCKET_ERROR;
        }
        return -1;
    }
    if (n==0) {
        FREE(buffer);
        force_close(ss, s, result);
        return SOCKET_CLOSE;
    }

    if (s->type == SOCKET_TYPE_HALFCLOSE) {
        // discard recv data
        FREE(buffer);
        return -1;
    }

    if (n == sz) {
        s->p.size *= 2;
    } else if (sz > MIN_READ_BUFFER && n*2 < sz) {
        s->p.size /= 2;
    }

    result->opaque = s->opaque;
    result->id = s->id;
    result->ud = n;
    result->data = buffer;
    return SOCKET_DATA;
}

/*
 * Pack `protocol` plus the port and address from `sa` into `udp_address`.
 * Layout: 1 tag byte, then sin_port/sin6_port, then sin_addr/sin6_addr,
 * all copied verbatim. Any protocol other than PROTOCOL_UDP is packed as IPv6.
 * Returns the number of bytes written (7 for IPv4, 19 for IPv6).
 */
static int
gen_udp_address(int protocol, union sockaddr_all *sa, uint8_t * udp_address) {
    uint8_t *cursor = udp_address;
    *cursor++ = (uint8_t)protocol;

    if (protocol == PROTOCOL_UDP) {
        memcpy(cursor, &sa->v4.sin_port, sizeof(sa->v4.sin_port));
        cursor += sizeof(sa->v4.sin_port);
        memcpy(cursor, &sa->v4.sin_addr, sizeof(sa->v4.sin_addr));
        cursor += sizeof(sa->v4.sin_addr);
    } else {
        memcpy(cursor, &sa->v6.sin6_port, sizeof(sa->v6.sin6_port));
        cursor += sizeof(sa->v6.sin6_port);
        memcpy(cursor, &sa->v6.sin6_addr, sizeof(sa->v6.sin6_addr));
        cursor += sizeof(sa->v6.sin6_addr);
    }
    return (int)(cursor - udp_address);
}

static int
forward_message_udp(struct socket_server *ss, struct socket *s, struct socket_message * result) {
    union sockaddr_all sa;
    socklen_t slen = sizeof(sa);
    int n = recvfrom(s->fd, ss->udpbuffer,MAX_UDP_PACKAGE,0,&sa.s,&slen);
    if (n<0) {
        switch(errno) {
        case EINTR:
        case EAGAIN:
            break;
        default:
            // close when error
            force_close(ss, s, result);
            return SOCKET_ERROR;
        }
        return -1;
    }
    uint8_t * data;
    if (slen == sizeof(sa.v4)) {
        if (s->protocol != PROTOCOL_UDP)
            return -1;
        data = MALLOC(n + 1 + 2 + 4);
        gen_udp_address(PROTOCOL_UDP, &sa, data + n);
    } else {
        if (s->protocol != PROTOCOL_UDPv6)
            return -1;
        data = MALLOC(n + 1 + 2 + 16);
        gen_udp_address(PROTOCOL_UDPv6, &sa, data + n);
    }
    memcpy(data, ss->udpbuffer, n);

    result->opaque = s->opaque;
    result->id = s->id;
    result->ud = n;
    result->data = (char *)data;

    return SOCKET_UDP;
}

static int
report_connect(struct socket_server *ss, struct socket *s, struct socket_message *result) {
    int error;
    socklen_t len = sizeof(error);
    int code = getsockopt(s->fd, SOL_SOCKET, SO_ERROR, &error, &len);
    if (code < 0 || error) {
        force_close(ss,s, result);
        return SOCKET_ERROR;
    } else {
        s->type = SOCKET_TYPE_CONNECTED;
        result->opaque = s->opaque;
        result->id = s->id;
        result->ud = 0;
        if (send_buffer_empty(s)) {
            sp_write(ss->event_fd, s->fd, s, false);
        }
        union sockaddr_all u;
        socklen_t slen = sizeof(u);
        if (getpeername(s->fd, &u.s, &slen) == 0) {
            void * sin_addr = (u.s.sa_family == AF_INET) ? (void*)&u.v4.sin_addr : (void *)&u.v6.sin6_addr;
            if (inet_ntop(u.s.sa_family, sin_addr, ss->buffer, sizeof(ss->buffer))) {
                result->data = ss->buffer;
                return SOCKET_OPEN;
            }
        }
        result->data = NULL;
        return SOCKET_OPEN;
    }
}

// return 0 when failed
static int
report_accept(struct socket_server *ss, struct socket *s, struct socket_message *result) {
    union sockaddr_all u;
    socklen_t len = sizeof(u);
    int client_fd = accept(s->fd, &u.s, &len);
    if (client_fd < 0) {
        return 0;
    }
    int id = reserve_id(ss);
    if (id < 0) {
        close(client_fd);
        return 0;
    }
    socket_keepalive(client_fd);
    sp_nonblocking(client_fd);
    struct socket *ns = new_fd(ss, id, client_fd, PROTOCOL_TCP, s->opaque, false);
    if (ns == NULL) {
        close(client_fd);
        return 0;
    }
    ns->type = SOCKET_TYPE_PACCEPT;
    result->opaque = s->opaque;
    result->id = s->id;
    result->ud = id;
    result->data = NULL;

    void * sin_addr = (u.s.sa_family == AF_INET) ? (void*)&u.v4.sin_addr : (void *)&u.v6.sin6_addr;
    int sin_port = ntohs((u.s.sa_family == AF_INET) ? u.v4.sin_port : u.v6.sin6_port);
    char tmp[INET6_ADDRSTRLEN];
    if (inet_ntop(u.s.sa_family, sin_addr, tmp, sizeof(tmp))) {
        snprintf(ss->buffer, sizeof(ss->buffer), "%s:%d", tmp, sin_port);
        result->data = ss->buffer;
    }

    return 1;
}

/*
 * After a SOCKET_CLOSE / SOCKET_ERROR for result->id, clear every
 * not-yet-dispatched event of the current batch that still points at that
 * (now invalid) socket, so it is skipped by socket_server_poll.
 */
static inline void
clear_closed_event(struct socket_server *ss, struct socket_message * result, int type) {
    if (type != SOCKET_CLOSE && type != SOCKET_ERROR) {
        return;
    }
    const int id = result->id;
    for (int i = ss->event_index; i < ss->event_n; i++) {
        struct event *e = &ss->ev[i];
        if (e->s != NULL && e->s->type == SOCKET_TYPE_INVALID && e->s->id == id) {
            e->s = NULL;
        }
    }
}

// return type
int
socket_server_poll(struct socket_server *ss, struct socket_message * result, int * more) {
    for (;;) {
        if (ss->checkctrl) {
            if (has_cmd(ss)) {
                int type = ctrl_cmd(ss, result);
                if (type != -1) {
                    clear_closed_event(ss, result, type);
                    return type;
                } else
                    continue;
            } else {
                ss->checkctrl = 0;
            }
        }
        if (ss->event_index == ss->event_n) {
            ss->event_n = sp_wait(ss->event_fd, ss->ev, MAX_EVENT);
            ss->checkctrl = 1;
            if (more) {
                *more = 0;
            }
            ss->event_index = 0;
            if (ss->event_n <= 0) {
                ss->event_n = 0;
                return -1;
            }
        }
        struct event *e = &ss->ev[ss->event_index++];
        struct socket *s = e->s;
        if (s == NULL) {
            // dispatch pipe message at beginning
            continue;
        }
        switch (s->type) {
        case SOCKET_TYPE_CONNECTING:
            return report_connect(ss, s, result);
        case SOCKET_TYPE_LISTEN:
            if (report_accept(ss, s, result)) {
                return SOCKET_ACCEPT;
            }
            break;
        case SOCKET_TYPE_INVALID:
            fprintf(stderr, "socket-server: invalid socket
");
            break;
        default:
            printf("get new message from client
");
            if (e->read) {
                int type;
                if (s->protocol == PROTOCOL_TCP) {
                    type = forward_message_tcp(ss, s, result);
                } else {
                    type = forward_message_udp(ss, s, result);
                    if (type == SOCKET_UDP) {
                        // try read again
                        --ss->event_index;
                        return SOCKET_UDP;
                    }
                }
                if (e->write) {
                    // Try to dispatch write message next step if write flag set.
                    e->read = false;
                    --ss->event_index;
                }
                if (type == -1)
                    break;
                clear_closed_event(ss, result, type);
                return type;
            }
            if (e->write) {
                int type = send_buffer(ss, s, result);
                if (type == -1)
                    break;
                clear_closed_event(ss, result, type);
                return type;
            }
            break;
        }
    }
}

/*
 * Write a control command to the control pipe.
 *
 * header[6] holds the command type and header[7] the payload length. The
 * length is a single byte, so `len` must be below 256. len+2 bytes are
 * written starting at header[6]; this assumes header[7] is the last header
 * byte, so the payload in request->u directly follows it.
 */
static void
send_request(struct socket_server *ss, struct request_package *request, char type, int len) {
    request->header[6] = (uint8_t)type;
    request->header[7] = (uint8_t)len;
    for (;;) {
        int n = write(ss->sendctrl_fd, &request->header[6], len+2);
        if (n<0) {
            if (errno != EINTR) {
                fprintf(stderr, "socket-server : send ctrl command error %s.\n", strerror(errno));
            }
            // NOTE(review): non-EINTR errors are retried indefinitely as well.
            continue;
        }
        // pipe writes this small are atomic: all or nothing
        assert(n == len+2);
        return;
    }
}

/*
 * Fill req->u.open for an 'O' (connect) command and reserve a socket id.
 *
 * Returns the host length, i.e. the number of bytes the caller must add to
 * sizeof(req->u.open) when sending. Returns -1 if the address is too long or
 * no id could be reserved.
 */
static int
open_request(struct socket_server *ss, struct request_package *req, uintptr_t opaque, const char *addr, int port) {
    int len = strlen(addr);
    // The payload length travels in one byte (send_request stores it as
    // (uint8_t)len), so sizeof(open) + len must stay <= 255. A total of 256
    // would wrap to 0 and desynchronise the control pipe.
    if (len + sizeof(req->u.open) >= 256) {
        fprintf(stderr, "socket-server : Invalid addr %s.\n",addr);
        return -1;
    }
    int id = reserve_id(ss);
    if (id < 0)
        return -1;
    req->u.open.opaque = opaque;
    req->u.open.id = id;
    req->u.open.port = port;
    memcpy(req->u.open.host, addr, len);
    req->u.open.host[len] = '\0';

    return len;
}

/*
 * Queue an 'O' command asking the server to open a TCP connection to
 * addr:port. Returns the reserved socket id, or -1 if the address is too long
 * or no id is available.
 */
int
socket_server_connect(struct socket_server *ss, uintptr_t opaque, const char * addr, int port) {
    struct request_package request;
    const int host_len = open_request(ss, &request, opaque, addr, port);
    if (host_len >= 0) {
        send_request(ss, &request, 'O', sizeof(request.u.open) + host_len);
        return request.u.open.id;
    }
    return -1;
}

static void
free_buffer(struct socket_server *ss, const void * buffer, int sz) {
    struct send_object so;
    send_object_init(ss, &so, (void *)buffer, sz);
    so.free_func((void *)buffer);
}

/*
 * Queue a high-priority 'D' send of `sz` bytes on socket `id`.
 * The server takes over `buffer`: it is freed right away when the id is
 * stale or invalid. Returns -1 in that case, otherwise the socket's current
 * wb_size (bytes pending in its write lists, as read at call time).
 */
int64_t
socket_server_send(struct socket_server *ss, int id, const void * buffer, int sz) {
    struct socket *s = &ss->slot[HASH_ID(id)];
    struct request_package request;

    if (s->type == SOCKET_TYPE_INVALID || s->id != id) {
        free_buffer(ss, buffer, sz);
        return -1;
    }

    request.u.send.buffer = (char *)buffer;
    request.u.send.sz = sz;
    request.u.send.id = id;
    send_request(ss, &request, 'D', sizeof(request.u.send));
    return s->wb_size;
}

/*
 * Queue a low-priority 'P' send of `sz` bytes on socket `id`.
 * The server takes over `buffer`: it is freed right away when the id is
 * stale or invalid.
 */
void
socket_server_send_lowpriority(struct socket_server *ss, int id, const void * buffer, int sz) {
    struct socket *s = &ss->slot[HASH_ID(id)];
    if (s->type == SOCKET_TYPE_INVALID || s->id != id) {
        free_buffer(ss, buffer, sz);
        return;
    }

    struct request_package request;
    request.u.send.buffer = (char *)buffer;
    request.u.send.sz = sz;
    request.u.send.id = id;
    send_request(ss, &request, 'P', sizeof(request.u.send));
}

void
socket_server_exit(struct socket_server *ss) {
    struct request_package request;
    send_request(ss, &request, 'X', 0);
}

void
socket_server_close(struct socket_server *ss, uintptr_t opaque, int id) {
    struct request_package request;
    request.u.close.id = id;
    request.u.close.opaque = opaque;
    send_request(ss, &request, 'K', sizeof(request.u.close));
}

// Resolve host:port and create a socket bound to the first resolved address,
// with SO_REUSEADDR set. `protocol` must be IPPROTO_TCP (stream socket) or
// IPPROTO_UDP (datagram socket). A NULL or empty host binds to 0.0.0.0.
// Returns the bound fd and stores its address family (AF_INET or AF_INET6)
// in *family. Returns -1 on failure; *family may already have been written
// if resolution succeeded.
static int
do_bind(const char *host, int port, int protocol, int *family) {
    int fd;
    int status;
    int reuse = 1;
    struct addrinfo ai_hints;
    struct addrinfo *ai_list = NULL;
    char portstr[16];
    if (host == NULL || host[0] == 0) {
        host = "0.0.0.0";    // INADDR_ANY
    }
    snprintf(portstr, sizeof(portstr), "%d", port);
    memset( &ai_hints, 0, sizeof( ai_hints ) );
    ai_hints.ai_family = AF_UNSPEC;
    if (protocol == IPPROTO_TCP) {
        ai_hints.ai_socktype = SOCK_STREAM;
    } else {
        assert(protocol == IPPROTO_UDP);
        ai_hints.ai_socktype = SOCK_DGRAM;
    }
    ai_hints.ai_protocol = protocol;

    status = getaddrinfo( host, portstr, &ai_hints, &ai_list );
    if ( status != 0 ) {
        return -1;
    }
    *family = ai_list->ai_family;
    fd = socket(*family, ai_list->ai_socktype, 0);
    if (fd < 0) {
        goto _failed_fd;
    }
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (void *)&reuse, sizeof(int))==-1) {
        goto _failed;
    }
    status = bind(fd, (struct sockaddr *)ai_list->ai_addr, ai_list->ai_addrlen);
    if (status != 0)
        goto _failed;

    freeaddrinfo( ai_list );
    return fd;
_failed:
    close(fd);
_failed_fd:
    freeaddrinfo( ai_list );
    return -1;
}

/* Bind a TCP socket to host:port and start listening with `backlog`. Returns the fd, or -1 on failure. */
static int
do_listen(const char * host, int port, int backlog) {
    int family = 0;
    const int fd = do_bind(host, port, IPPROTO_TCP, &family);
    if (fd < 0) {
        return -1;
    }
    if (listen(fd, backlog) != 0) {
        close(fd);
        return -1;
    }
    return fd;
}

/*
 * Open a listening TCP socket on addr:port and queue an 'L' command that
 * registers it. Returns the reserved id, or a negative value when listening
 * fails or no id is available (the fd is closed in that case).
 */
int
socket_server_listen(struct socket_server *ss, uintptr_t opaque, const char * addr, int port, int backlog) {
    struct request_package request;
    int id;
    const int fd = do_listen(addr, port, backlog);

    if (fd < 0) {
        return -1;
    }
    id = reserve_id(ss);
    if (id < 0) {
        close(fd);
        return id;
    }

    request.u.listen.fd = fd;
    request.u.listen.id = id;
    request.u.listen.opaque = opaque;
    send_request(ss, &request, 'L', sizeof(request.u.listen));
    return id;
}

int
socket_server_bind(struct socket_server *ss, uintptr_t opaque, int fd) {
    struct request_package request;
    int id = reserve_id(ss);
    if (id < 0)
        return -1;
    request.u.bind.opaque = opaque;
    request.u.bind.id = id;
    request.u.bind.fd = fd;
    send_request(ss, &request, 'B', sizeof(request.u.bind));
    return id;
}

void
socket_server_start(struct socket_server *ss, uintptr_t opaque, int id) {
    struct request_package request;
    request.u.start.id = id;
    request.u.start.opaque = opaque;
    send_request(ss, &request, 'S', sizeof(request.u.start));
}

void
socket_server_nodelay(struct socket_server *ss, int id) {
    struct request_package request;
    request.u.setopt.id = id;
    request.u.setopt.what = TCP_NODELAY;
    request.u.setopt.value = 1;
    send_request(ss, &request, 'T', sizeof(request.u.setopt));
}

void
socket_server_userobject(struct socket_server *ss, struct socket_object_interface *soi) {
    ss->soi = *soi;
}

// UDP

/*
 * Create a non-blocking UDP socket and queue a 'U' command registering it.
 * If `port` is non-zero or `addr` is non-NULL, the socket is bound to
 * addr:port; otherwise an unbound AF_INET socket is created.
 * Returns the reserved id, or -1 on failure (the fd is closed if no id is free).
 */
int
socket_server_udp(struct socket_server *ss, uintptr_t opaque, const char * addr, int port) {
    int fd;
    int family = AF_INET;
    const int want_bind = (port != 0 || addr != NULL);

    if (want_bind) {
        fd = do_bind(addr, port, IPPROTO_UDP, &family);
    } else {
        fd = socket(family, SOCK_DGRAM, 0);
    }
    if (fd < 0) {
        return -1;
    }
    sp_nonblocking(fd);

    const int id = reserve_id(ss);
    if (id < 0) {
        close(fd);
        return -1;
    }

    struct request_package request;
    request.u.udp.family = family;
    request.u.udp.opaque = opaque;
    request.u.udp.fd = fd;
    request.u.udp.id = id;
    send_request(ss, &request, 'U', sizeof(request.u.udp));
    return id;
}

/*
 * Queue an 'A' command sending `sz` bytes to the packed UDP address `addr`
 * on socket `id`. The server takes over `buffer`: it is freed right away
 * (and -1 returned) when the id is stale or invalid, or when the address tag
 * is neither PROTOCOL_UDP nor PROTOCOL_UDPv6.
 * Otherwise returns the socket's current wb_size.
 */
int64_t
socket_server_udp_send(struct socket_server *ss, int id, const struct socket_udp_address *addr, const void *buffer, int sz) {
    struct socket *s = &ss->slot[HASH_ID(id)];
    if (s->id != id || s->type == SOCKET_TYPE_INVALID) {
        free_buffer(ss, buffer, sz);
        return -1;
    }

    const uint8_t *packed = (const uint8_t *)addr;
    int addrsz;
    if (packed[0] == PROTOCOL_UDP) {
        addrsz = 1+2+4;        // 1 type, 2 port, 4 ipv4
    } else if (packed[0] == PROTOCOL_UDPv6) {
        addrsz = 1+2+16;       // 1 type, 2 port, 16 ipv6
    } else {
        free_buffer(ss, buffer, sz);
        return -1;
    }

    struct request_package request;
    request.u.send_udp.send.id = id;
    request.u.send_udp.send.sz = sz;
    request.u.send_udp.send.buffer = (char *)buffer;
    memcpy(request.u.send_udp.address, packed, addrsz);

    send_request(ss, &request, 'A', sizeof(request.u.send_udp.send) + addrsz);
    return s->wb_size;
}

/*
 * Resolve addr:port (UDP) and queue a 'C' command that sets it as the
 * default destination of socket `id`. The first resolved address is used.
 * Returns 0 when the command was queued, or -1 when resolution fails or
 * yields a family other than AF_INET/AF_INET6.
 */
int
socket_server_udp_connect(struct socket_server *ss, int id, const char * addr, int port) {
    struct addrinfo hints;
    struct addrinfo *res = NULL;
    char portstr[16];
    struct request_package request;
    int protocol;

    snprintf(portstr, sizeof(portstr), "%d", port);
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_DGRAM;
    hints.ai_protocol = IPPROTO_UDP;
    if (getaddrinfo(addr, portstr, &hints, &res) != 0) {
        return -1;
    }

    switch (res->ai_family) {
    case AF_INET:
        protocol = PROTOCOL_UDP;
        break;
    case AF_INET6:
        protocol = PROTOCOL_UDPv6;
        break;
    default:
        freeaddrinfo(res);
        return -1;
    }

    request.u.set_udp.id = id;
    const int addrsz = gen_udp_address(protocol, (union sockaddr_all *)res->ai_addr, request.u.set_udp.address);
    freeaddrinfo(res);

    send_request(ss, &request, 'C', sizeof(request.u.set_udp) - sizeof(request.u.set_udp.address) + addrsz);
    return 0;
}

const struct socket_udp_address *
socket_server_udp_address(struct socket_server *ss, struct socket_message *msg, int *addrsz) {
    /*
     * Locate the sender address of a UDP message. It is read from
     * msg->data + msg->ud, i.e. just past the ud bytes of payload.
     * The first byte is the protocol tag; *addrsz receives the encoded size
     * (1 type + 2 port + 4 or 16 address bytes). Returns NULL for an
     * unknown tag (in which case *addrsz is left untouched).
     */
    const uint8_t *tail = (const uint8_t *)(msg->data + msg->ud);
    int tag = tail[0];
    if (tag == PROTOCOL_UDP) {
        *addrsz = 1+2+4;
    } else if (tag == PROTOCOL_UDPv6) {
        *addrsz = 1+2+16;
    } else {
        return NULL;
    }
    return (const struct socket_udp_address *)tail;
}
View Code

socket_server.h

#ifndef skynet_socket_server_h
#define skynet_socket_server_h

/*
 * Public interface of the socket server: an event-polling TCP/UDP socket
 * layer. Control commands (listen/connect/start/close/send/...) are issued
 * from one thread; events are collected with socket_server_poll() in
 * another thread.
 */

#include <stdint.h>

// custom malloc/free
//#define SOCKET_SERVER_FILE_MEMAPI   skynet.h
//#define SOCKET_SERVER_MALLOC        skynet_malloc
//#define SOCKET_SERVER_FREE          skynet_free

// Event types returned by socket_server_poll().
#define SOCKET_DATA 0
#define SOCKET_CLOSE 1
#define SOCKET_OPEN 2
#define SOCKET_ACCEPT 3
#define SOCKET_ERROR 4
#define SOCKET_EXIT 5
#define SOCKET_UDP 6

struct socket_server;

// One event reported by socket_server_poll().
struct socket_message {
    int id;
    uintptr_t opaque;
    int ud;    // for accept, ud is listen id ; for data, ud is size of data
    char * data;
};

// (void): a real prototype; an empty () list declares no parameter types before C23.
struct socket_server * socket_server_create(void);
void socket_server_release(struct socket_server *);
// Wait for the next event; returns one of the SOCKET_* event types above.
int socket_server_poll(struct socket_server *, struct socket_message *result, int *more);

void socket_server_exit(struct socket_server *);
void socket_server_close(struct socket_server *, uintptr_t opaque, int id);
void socket_server_start(struct socket_server *, uintptr_t opaque, int id);

// return -1 when error
int64_t socket_server_send(struct socket_server *, int id, const void * buffer, int sz);
void socket_server_send_lowpriority(struct socket_server *, int id, const void * buffer, int sz);

// ctrl command below returns id
int socket_server_listen(struct socket_server *, uintptr_t opaque, const char * addr, int port, int backlog);
int socket_server_connect(struct socket_server *, uintptr_t opaque, const char * addr, int port);
int socket_server_bind(struct socket_server *, uintptr_t opaque, int fd);

// for tcp
void socket_server_nodelay(struct socket_server *, int id);

struct socket_udp_address;

// create an udp socket handle, attach opaque with it . udp socket don't need call socket_server_start to recv message
// if port != 0, bind the socket . if addr == NULL, bind ipv4 0.0.0.0 . If you want to use ipv6, addr can be "::" and port 0.
int socket_server_udp(struct socket_server *, uintptr_t opaque, const char * addr, int port);
// set default dest address, return 0 when success
int socket_server_udp_connect(struct socket_server *, int id, const char * addr, int port);
// If the socket_udp_address is NULL, use last call socket_server_udp_connect address instead
// You can also use socket_server_send
// NOTE(review): the socket_server_udp_send implementation in socket_server.c reads
// the first byte of the address unconditionally, so passing NULL appears to
// dereference a null pointer. Use socket_server_send for the connected address
// until this is confirmed/fixed.
int64_t socket_server_udp_send(struct socket_server *, int id, const struct socket_udp_address *, const void *buffer, int sz);
// extract the address of the message, struct socket_message * should be SOCKET_UDP
const struct socket_udp_address * socket_server_udp_address(struct socket_server *, struct socket_message *, int *addrsz);

struct socket_object_interface {
    void * (*buffer)(void *);
    int (*size)(void *);
    void (*free)(void *);
};

// if you send package sz == -1, use soi.
void socket_server_userobject(struct socket_server *, struct socket_object_interface *soi);

#endif
View Code

test.c

#include "socket_server.h"
#include "message_resolve.h"
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <signal.h>
#include <string.h>

/*
 * Feed one chunk of received data (result->ud bytes from connection
 * result->id) into the package unpacker filter_data_().
 *
 * A return value of 0 from filter_data_() is treated as "md now holds one
 * complete package". In that case the package is printed and its buffer
 * released. Any other value means more bytes are still needed.
 *
 * Ownership: result->data is always freed here after filter_data_() returns,
 * so filter_data_() must not keep a reference to it.
 */
static void
unpack(struct socket_message *result, struct my_data *md)
{
    int ret = filter_data_((uint8_t*)result->data, result->ud, result->id, md);
    if (0 == ret) { // got a complete package
        // The payload is length-delimited, not necessarily NUL-terminated,
        // so bound the printed bytes by md->size instead of using plain %s.
        printf("size[%d], data[%.*s]\n",
               (int)md->size, (int)md->size, (const char *)md->data);
        free(md->data);
        md->data = NULL;
        md->size = 0;
    } else {
        printf("not a completed package\n");
    }
    free(result->data);
    result->data = NULL; // guard against accidental reuse / double free
}

static void *
_poll(void * ud) {
    struct socket_server *ss = ud;
    struct socket_message result; 
    struct my_data md;

    for (;;) {
        int type = socket_server_poll(ss, &result, NULL);
        // DO NOT use any ctrl command (socket_server_close , etc. ) in this thread.
        switch (type) {
        case SOCKET_EXIT:
            return NULL;
        case SOCKET_DATA:
            printf("SOCKET_DATA
");

            unpack(&result, &md);
        
            break;
        case SOCKET_CLOSE:
            printf("close(%lu) [id=%d]
",result.opaque,result.id);
            break;
        case SOCKET_OPEN:
            printf("open(%lu) [id=%d] %s
",result.opaque,result.id,result.data);
            break;
        case SOCKET_ERROR:
            printf("error(%lu) [id=%d]
",result.opaque,result.id);
            break;
        case SOCKET_ACCEPT:
            printf("accept(%lu) [id=%d %s] from [%d]
",result.opaque, result.ud, result.data, result.id);
            socket_server_start(ss, 301, result.ud);
            break;
        }
    }
}

/*
 * Start the poll thread, listen on 127.0.0.1:8888 and wait for the poll
 * thread to finish (it returns when socket_server_exit() is called).
 */
static void
test(struct socket_server *ss) {
    pthread_t pid;
    if (pthread_create(&pid, NULL, _poll, ss) != 0) {
        fprintf(stderr, "pthread_create failed\n");
        return;
    }

/*    int c = socket_server_connect(ss,100,"127.0.0.1",80);
    printf("connecting %d\n",c);*/
    int l = socket_server_listen(ss,200,"127.0.0.1",8888,32);
    printf("listening %d\n",l);
    if (l < 0) {
        // Presumably a negative id signals listen failure (e.g. port in use).
        // Stop the poll thread so pthread_join below does not block forever.
        fprintf(stderr, "listen failed\n");
        socket_server_exit(ss);
    } else {
        socket_server_start(ss,201,l);
    }
    /*int b = socket_server_bind(ss,300,1);
    printf("binding stdin %d\n",b);
    int i;
    for (i=0;i<100;i++) {
        socket_server_connect(ss, 400+i, "127.0.0.1", 8888);
    }
    sleep(5);
    socket_server_exit(ss);*/

    pthread_join(pid, NULL);
}


int
main(void) {
    // Ignore SIGPIPE so writing to a peer-closed socket reports an error
    // instead of terminating the process. sa is fully initialised
    // (mask and flags) before being passed to sigaction().
    struct sigaction sa;
    memset(&sa, 0, sizeof(sa));
    sa.sa_handler = SIG_IGN;
    sigemptyset(&sa.sa_mask);
    if (sigaction(SIGPIPE, &sa, NULL) != 0) {
        perror("sigaction");
    }

    init_queue();
    struct socket_server * ss = socket_server_create();
    if (ss == NULL) {
        fprintf(stderr, "socket_server_create failed\n");
        return 1;
    }
    test(ss);
    socket_server_release(ss);

    return 0;
}
View Code

5.以上代码可先从 https://github.com/cloudwu/socket-server 下载得到，然后直接加上 message_resolve.c/h，并替换 test.c 即可。

 message_resolve.c 可参考 skynet（https://github.com/cloudwu/skynet）源码中的 lua-netpack.c。

6.水平有限，仅供学习交流，欢迎指正错误。

  

    

原文地址:https://www.cnblogs.com/newbeeyu/p/6906691.html