网络代理之HTTP代理(golang反向代理、负载均衡算法实现)

网络代理与网络转发的区别

网络代理:

用户不直接连接服务器,网络代理去连接,获取数据后返回给用户

网络转发:

是路由器对报文的转发操作,中间可能对数据包修改

网络代理类型:

 

正向代理:

 实现一个web浏览器代理:

 代码实现一个web浏览器代理:

 代码实现:

package main

import (
    "fmt"
    "io"
    "net"
    "net/http"
    "strings"
)

// Pxy is a minimal forward (browser) HTTP proxy: every incoming request is
// re-sent to the host named in its URL and the reply is copied back.
type Pxy struct{}

// ServeHTTP forwards req upstream and streams the upstream response to rw.
//
// The client IP (taken from req.RemoteAddr) is appended to X-Forwarded-For.
// The request is cloned first so that header changes never leak into the
// caller's *http.Request. If the upstream round trip fails, the client gets
// 502 Bad Gateway.
func (p *Pxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	fmt.Printf("Received request %s %s %s\n", req.Method, req.Host, req.RemoteAddr)
	transport := http.DefaultTransport

	// step 1: deep-copy the request (including its Header map) before adding
	// proxy headers; a plain struct copy would share and mutate req.Header.
	outReq := req.Clone(req.Context())
	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if prior, ok := outReq.Header["X-Forwarded-For"]; ok {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		outReq.Header.Set("X-Forwarded-For", clientIP)
	}

	// step 2: send the request to the downstream server.
	res, err := transport.RoundTrip(outReq)
	if err != nil {
		fmt.Printf("proxy error: %v\n", err)
		rw.WriteHeader(http.StatusBadGateway)
		return
	}
	defer res.Body.Close()

	// step 3: copy headers, status and body of the downstream response back
	// to the client.
	for key, values := range res.Header {
		for _, v := range values {
			rw.Header().Add(key, v)
		}
	}
	rw.WriteHeader(res.StatusCode)
	// Headers are already sent, so a copy error can no longer be reported.
	io.Copy(rw, res.Body)
}

// main starts the forward proxy on port 8080 and reports why it stopped.
func main() {
	fmt.Println("Serve on :8080")
	http.Handle("/", &Pxy{})
	// ListenAndServe only returns on failure (e.g. the port is already in
	// use); surface that error instead of exiting silently.
	if err := http.ListenAndServe("0.0.0.0:8080", nil); err != nil {
		fmt.Println("ListenAndServe:", err)
	}
}

反向代理:

如何实现一个反向代理:

  • 这个功能比较复杂,我们先实现一个简版的http反向代理。
  • 代理接收客户端请求,更改请求结构体信息
  • 通过一定的负载均衡算法获取下游服务地址
  • 把请求发送到下游服务器,并获取返回的内容
  • 对返回的内容做一些处理,然后返回给客户端

启动两个http服务(真实服务地址)

127.0.0.1:2003
127.0.0.1:2004
package main

import (
    "fmt"
    "io"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"
)

// main starts the two demo real servers and blocks until SIGINT or SIGTERM.
func main() {
	rs1 := &RealServer{Addr: "127.0.0.1:2003"}
	rs1.Run()
	rs2 := &RealServer{Addr: "127.0.0.1:2004"}
	rs2.Run()

	// Wait for a shutdown signal. signal.Notify never blocks when delivering,
	// so the channel needs a buffer or a signal sent before we start receiving
	// can be lost.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
}

// RealServer is a demo backend ("real server") that the proxies forward to.
type RealServer struct {
	Addr string // listen address, e.g. "127.0.0.1:2003"
}

// Run starts an HTTP server on r.Addr in a background goroutine and returns
// immediately. A listen failure terminates the process via log.Fatal.
func (r *RealServer) Run() {
	log.Println("Starting httpserver at " + r.Addr)
	mux := http.NewServeMux()
	mux.HandleFunc("/", r.HelloHandler)
	mux.HandleFunc("/base/error", r.ErrorHandler)
	server := &http.Server{
		Addr:         r.Addr,
		ReadTimeout:  time.Second * 3,
		WriteTimeout: time.Second * 3,
		Handler:      mux,
	}
	go func() {
		log.Fatal(server.ListenAndServe())
	}()
}

// HelloHandler writes two lines: the URL this server was reached at, then
// the client-address information (RemoteAddr, X-Forwarded-For, X-Real-Ip),
// so the effect of each proxy layer can be observed.
func (r *RealServer) HelloHandler(w http.ResponseWriter, req *http.Request) {
	// e.g. request 127.0.0.1:8008/abc?sdsdsa=11
	//   r.Addr       = 127.0.0.1:8008
	//   req.URL.Path = /abc
	fmt.Println(req.Host)
	upath := fmt.Sprintf("http://%s%s\n", r.Addr, req.URL.Path)
	realIP := fmt.Sprintf("RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n",
		req.RemoteAddr, req.Header.Get("X-Forwarded-For"), req.Header.Get("X-Real-Ip"))

	io.WriteString(w, upath)
	io.WriteString(w, realIP)
}

// ErrorHandler always answers 500 with the body "error handler"; it backs
// the /base/error route used to exercise the proxies' error paths.
func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
	upath := "error handler"
	w.WriteHeader(http.StatusInternalServerError)
	io.WriteString(w, upath)
}
real_server

启动一个代理服务

代理服务 127.0.0.1:2002(此处代码并没有使用负载均衡算法,只是简单地固定代理到其中一个服务器)

package main

import (
    "bufio"
    "log"
    "net/http"
    "net/url"
)

var (
	// proxy_addr is the downstream (real server) base URL.
	proxy_addr = "http://127.0.0.1:2003"
	// port is the port this proxy listens on.
	port = "2002"
)

// handler is a minimal single-target reverse proxy. It points a copy of the
// request at proxy_addr, forwards it, and copies the downstream status code,
// headers and body back to the client. Downstream failures are reported as
// 502 Bad Gateway.
func handler(w http.ResponseWriter, r *http.Request) {
	// step 1: parse the target and point a copy of the request at it.
	// Handlers must not modify the incoming request, hence the clone.
	proxy, err := url.Parse(proxy_addr)
	if err != nil {
		log.Print(err)
		http.Error(w, "bad proxy address", http.StatusInternalServerError)
		return
	}
	outReq := r.Clone(r.Context())
	outReq.URL.Scheme = proxy.Scheme
	outReq.URL.Host = proxy.Host

	// step 2: send the request downstream.
	transport := http.DefaultTransport
	resp, err := transport.RoundTrip(outReq)
	if err != nil {
		log.Print(err)
		w.WriteHeader(http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// step 3: copy headers, status code and body back upstream.
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	bufio.NewReader(resp.Body).WriteTo(w)
}

func main() {
    http.HandleFunc("/", handler)
    log.Println("Start serving on port " + port)
    err := http.ListenAndServe(":"+port, nil)
    if err != nil {
        log.Fatal(err)
    }
}
reverse_proxy

用户访问127.0.0.1:2002   反向代理到  127.0.0.1:2003

http代理

上面演示的是一个简版的http代理,不具备以下功能:

  • 错误回调及错误日志处理
  • 更改代理返回内容
  • 负载均衡
  • url重写
  • 限流、熔断、降级

用golang官方提供的ReverseProxy实现一个http代理

  • ReverseProxy功能点
  • ReverseProxy实例
  • ReverseProxy源码实现

拓展ReverseProxy功能

  • 4种负载均衡策略的实现以及接口封装
  • 拓展中间件支持:限流、熔断实现、权限、数据统计

用ReverseProxy实现一个http代理:

package main

import (
    "log"
    "net/http"
    "net/http/httputil"
    "net/url"
)

var addr = "127.0.0.1:2002"

// main proxies every request received on addr to the real server under the
// /base prefix, e.g. 127.0.0.1:2002/xxx => 127.0.0.1:2003/base/xxx.
func main() {
	rs1 := "http://127.0.0.1:2003/base"
	url1, err1 := url.Parse(rs1)
	if err1 != nil {
		// A nil *url.URL would panic inside NewSingleHostReverseProxy, so
		// there is nothing sensible to continue with.
		log.Fatal(err1)
	}
	proxy := httputil.NewSingleHostReverseProxy(url1)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}

ReverseProxy修改返回的内容

重写 

httputil.NewSingleHostReverseProxy(url1)
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"regexp"
	"strings"
)

var addr = "127.0.0.1:2002"

// main serves the customised single-host reverse proxy on addr, forwarding
// to the real server under /base (127.0.0.1:2002/xxx => 127.0.0.1:2003/base/xxx).
func main() {
	rs1 := "http://127.0.0.1:2003/base"
	url1, err1 := url.Parse(rs1)
	if err1 != nil {
		// The proxy dereferences its target, so a nil URL cannot be used.
		log.Fatal(err1)
	}
	proxy := NewSingleHostReverseProxy(url1)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}

// NewSingleHostReverseProxy returns a ReverseProxy that routes every request
// to target. It behaves like httputil.NewSingleHostReverseProxy, with these
// additions:
//   - a leading "/dir" path prefix is stripped before forwarding
//     (127.0.0.1:2002/dir/abc ==> 127.0.0.1:2003/base/abc);
//   - target.Path is prepended and target.RawQuery is merged into the query;
//   - ModifyResponse turns non-200 upstream responses into an error, so
//     ReverseProxy invokes the error handler;
//   - the error handler replies 502 Bad Gateway with the error text.
func NewSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy {
	// e.g. for http://127.0.0.1:2002/dir?name=123:
	//   RawQuery: name=123, Scheme: http, Host: 127.0.0.1:2002
	targetQuery := target.RawQuery
	// Compile the rewrite rule once, not on every request.
	dirPrefix := regexp.MustCompile("^/dir(.*)")

	director := func(req *http.Request) {
		// url_rewrite: 127.0.0.1:2002/dir/abc ==> 127.0.0.1:2002/abc
		req.URL.Path = dirPrefix.ReplaceAllString(req.URL.Path, "$1")

		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host

		// 127.0.0.1:2002/abc ==> 127.0.0.1:2003/base/abc (target.Path is /base)
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			// Explicitly empty so the transport does not add its default
			// User-Agent.
			req.Header.Set("User-Agent", "")
		}
	}

	// rewriteErrorBody selects how non-200 responses are handled:
	//   false: fail the response so ErrorHandler runs (current demo);
	//   true:  keep the response and prefix its body with "hello ".
	// Previously this choice was an early return that made the rewrite code
	// unreachable.
	const rewriteErrorBody = false
	modifyFunc := func(res *http.Response) error {
		if res.StatusCode == 200 {
			return nil
		}
		if !rewriteErrorBody {
			return errors.New("error statusCode")
		}
		oldPayload, err := ioutil.ReadAll(res.Body)
		// The body is replaced below, so the original must be closed here.
		res.Body.Close()
		if err != nil {
			return err
		}
		newPayLoad := []byte("hello " + string(oldPayload))
		res.Body = ioutil.NopCloser(bytes.NewBuffer(newPayLoad))
		res.ContentLength = int64(len(newPayLoad))
		res.Header.Set("Content-Length", fmt.Sprint(len(newPayLoad)))
		return nil
	}
	errorHandler := func(res http.ResponseWriter, req *http.Request, err error) {
		// Report upstream / ModifyResponse failures as 502 instead of an
		// implicit 200.
		res.WriteHeader(http.StatusBadGateway)
		res.Write([]byte(err.Error()))
	}
	return &httputil.ReverseProxy{Director: director, ModifyResponse: modifyFunc, ErrorHandler: errorHandler}
}

// singleJoiningSlash joins two URL path segments with exactly one slash
// between them. For example, ("/base", "abc"), ("/base/", "/abc") and
// ("/base", "/abc") all give "/base/abc".
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}

  

ReverseProxy补充知识:

特殊Header头:X-Forwarded-For、X-Real-Ip、Connection、TE、Trailer

代理转发前需要去除标准的逐段传输头(hop-by-hop header,如 Connection、TE、Trailer),这些头只对单跳连接有效

X-Forwarded-For

  • 记录最后直连实际服务器之前,整个代理过程
  • 可能会被伪造

X-Real-Ip

  • 记录发起请求的客户端真实IP
  • 每过一层代理如果重新设置都会被覆盖,因此只需要在第一层代理中设置
  • 由第一层代理直接覆盖设置(不沿用客户端传入的值),因此不会被客户端伪造

 代码实现:

package main

import (
    "bytes"
    "io/ioutil"
    "log"
    "math/rand"
    "net"
    "net/http"
    "net/http/httputil"
    "net/url"
    "regexp"
    "strconv"
    "strings"
    "time"
)

var addr = "127.0.0.1:2001"

// main starts the first proxy layer on addr, forwarding to the second layer
// at 127.0.0.1:2002.
func main() {
	rs1 := "http://127.0.0.1:2002"
	url1, err1 := url.Parse(rs1)
	if err1 != nil {
		// A nil target would make the director panic on every request.
		log.Fatal(err1)
	}
	urls := []*url.URL{url1}
	proxy := NewMultipleHostsReverseProxy(urls)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}

// transport is the upstream transport shared by the proxy. It bounds the
// dial, TLS-handshake and 100-continue waits, and keeps a pool of idle
// keep-alive connections.
var transport = func() *http.Transport {
	dialer := &net.Dialer{
		// connect timeout
		Timeout: 30 * time.Second,
		// TCP keep-alive period
		KeepAlive: 30 * time.Second,
	}
	t := &http.Transport{DialContext: dialer.DialContext}
	// maximum number of idle connections kept across all hosts
	t.MaxIdleConns = 100
	// how long an idle connection stays in the pool
	t.IdleConnTimeout = 90 * time.Second
	// TLS handshake timeout
	t.TLSHandshakeTimeout = 10 * time.Second
	// how long to wait for "100 Continue" before sending the body
	t.ExpectContinueTimeout = 1 * time.Second
	return t
}()

// NewMultipleHostsReverseProxy returns a ReverseProxy that chooses one of
// targets at random for every request. It is the first (outermost) proxy
// layer of the demo chain. It:
//   - strips a leading "/dir" path prefix;
//   - prepends the chosen target's path and merges its query;
//   - records the client IP in X-Real-Ip.
//
// The body of a non-200 upstream response is prefixed with
// "StatusCode error:". Transport failures are answered with 502 Bad Gateway.
// If targets is empty, requests fail through the error handler instead of
// panicking.
func NewMultipleHostsReverseProxy(targets []*url.URL) *httputil.ReverseProxy {
	// Compile the URL rewrite rule once, not on every request.
	dirPrefix := regexp.MustCompile("^/dir(.*)")

	// director rewrites the outgoing request in place.
	director := func(req *http.Request) {
		// url_rewrite: 127.0.0.1:2002/dir/abc ==> 127.0.0.1:2002/abc
		req.URL.Path = dirPrefix.ReplaceAllString(req.URL.Path, "$1")

		if len(targets) == 0 {
			// rand.Intn(0) would panic. Clear scheme and host so the
			// transport rejects the request and errFunc answers the client.
			req.URL.Scheme = ""
			req.URL.Host = ""
			return
		}

		// Random load balancing.
		target := targets[rand.Intn(len(targets))]
		targetQuery := target.RawQuery
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host

		// Path rewrite: /aa ==> /base/aa when target.Path is /base.
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "user-agent")
		}
		// Only the first proxy layer sets X-Real-Ip; it overwrites any value
		// sent by the client. RemoteAddr is "ip:port", so keep just the IP.
		realIP := req.RemoteAddr
		if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			realIP = host
		}
		req.Header.Set("X-Real-Ip", realIP)
	}

	// modifyFunc prefixes the body of non-200 responses,
	// e.g. for curl 'http://127.0.0.1:2002/error'.
	modifyFunc := func(resp *http.Response) error {
		if resp.StatusCode != 200 {
			oldPayload, err := ioutil.ReadAll(resp.Body)
			// The body is replaced below, so the original must be closed.
			resp.Body.Close()
			if err != nil {
				return err
			}
			newPayload := []byte("StatusCode error:" + string(oldPayload))
			resp.Body = ioutil.NopCloser(bytes.NewBuffer(newPayload))
			resp.ContentLength = int64(len(newPayload))
			resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(newPayload)), 10))
		}
		return nil
	}

	// errFunc handles transport failures (e.g. when real_server is shut down).
	// 502 signals that the upstream, not this proxy, failed.
	errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
		http.Error(w, "ErrorHandler error:"+err.Error(), http.StatusBadGateway)
	}
	return &httputil.ReverseProxy{
		Director:       director,
		Transport:      transport,
		ModifyResponse: modifyFunc,
		ErrorHandler:   errFunc}
}

// singleJoiningSlash joins two URL path segments with exactly one slash
// between them, e.g. ("/base", "/aa") => "/base/aa".
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}
第一层代理

第二层代理

package main

import (
    "bytes"
    "compress/gzip"
    "io/ioutil"
    "log"
    "math/rand"
    "net"
    "net/http"
    "net/http/httputil"
    "net/url"
    "regexp"
    "strconv"
    "strings"
    "time"
)

var addr = "127.0.0.1:2002"

// main starts the second proxy layer on addr and balances randomly between
// the two real servers.
func main() {
	rawTargets := []string{
		"http://127.0.0.1:2003", // or a public site such as "http://www.baidu.com"
		"http://127.0.0.1:2004",
	}
	urls := make([]*url.URL, 0, len(rawTargets))
	for _, raw := range rawTargets {
		u, err := url.Parse(raw)
		if err != nil {
			// A nil target would make the director panic on every request.
			log.Fatal(err)
		}
		urls = append(urls, u)
	}
	proxy := NewMultipleHostsReverseProxy(urls)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}

// transport is the upstream transport shared by the second proxy layer:
// bounded dial, TLS-handshake and 100-continue waits plus a pool of idle
// keep-alive connections.
var transport = func() *http.Transport {
	d := &net.Dialer{}
	// connect timeout
	d.Timeout = 30 * time.Second
	// TCP keep-alive period
	d.KeepAlive = 30 * time.Second
	return &http.Transport{
		DialContext: d.DialContext,
		// maximum number of idle connections across all hosts
		MaxIdleConns: 100,
		// how long an idle connection stays in the pool
		IdleConnTimeout: 90 * time.Second,
		// TLS handshake timeout
		TLSHandshakeTimeout: 10 * time.Second,
		// how long to wait for "100 Continue" before sending the body
		ExpectContinueTimeout: 1 * time.Second,
	}
}()

// NewMultipleHostsReverseProxy returns a ReverseProxy that chooses one of
// targets at random for every request. It is the second layer of the demo
// chain. Unlike the first layer, it does not set X-Real-Ip, and it sets the
// outgoing Host header to the target host.
//
// ModifyResponse behaves as follows:
//   - protocol upgrades (e.g. WebSocket) pass through untouched;
//   - otherwise the body is fully read, and gzip-decoded when
//     Content-Encoding contains "gzip";
//   - the body is prefixed with "StatusCode error:" for non-200 statuses;
//   - the result is written back with a matching Content-Length.
//
// Transport failures are answered with 502 Bad Gateway. If targets is empty,
// requests fail through the error handler instead of panicking.
func NewMultipleHostsReverseProxy(targets []*url.URL) *httputil.ReverseProxy {
	// Compile the URL rewrite rule once, not on every request.
	dirPrefix := regexp.MustCompile("^/dir(.*)")

	// director rewrites the outgoing request in place.
	director := func(req *http.Request) {
		// url_rewrite: 127.0.0.1:2002/dir/abc ==> 127.0.0.1:2002/abc
		req.URL.Path = dirPrefix.ReplaceAllString(req.URL.Path, "$1")

		if len(targets) == 0 {
			// rand.Intn(0) would panic. Clear scheme and host so the
			// transport rejects the request and errFunc answers the client.
			req.URL.Scheme = ""
			req.URL.Host = ""
			return
		}

		// Random load balancing.
		target := targets[rand.Intn(len(targets))]
		targetQuery := target.RawQuery
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host

		// Needed when proxying to a public domain, so that the Host header
		// matches the target. Not needed for an internal backend.
		req.Host = target.Host

		// Path rewrite: /aa ==> /base/aa when target.Path is /base.
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "user-agent")
		}
		// X-Real-Ip is set only by the first proxy layer, so it is not
		// touched here:
		//req.Header.Set("X-Real-Ip", req.RemoteAddr)
	}

	// modifyFunc rewrites responses, e.g. for curl 'http://127.0.0.1:2002/error'.
	modifyFunc := func(resp *http.Response) error {
		// WebSocket / protocol upgrade: the body is the raw connection and
		// must not be consumed. Match the 101 status and, case-insensitively,
		// the Connection header.
		if resp.StatusCode == http.StatusSwitchingProtocols ||
			strings.Contains(strings.ToLower(resp.Header.Get("Connection")), "upgrade") {
			return nil
		}

		// The original body is replaced below, so close it once read.
		body := resp.Body
		defer body.Close()

		var payload []byte
		var readErr error
		if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
			// Decompress so the payload can be edited; the result is sent
			// uncompressed, so the Content-Encoding header is dropped.
			gr, err := gzip.NewReader(body)
			if err != nil {
				return err
			}
			payload, readErr = ioutil.ReadAll(gr)
			gr.Close()
			resp.Header.Del("Content-Encoding")
		} else {
			payload, readErr = ioutil.ReadAll(body)
		}
		if readErr != nil {
			return readErr
		}

		// Mark non-200 responses in the body.
		if resp.StatusCode != 200 {
			payload = []byte("StatusCode error:" + string(payload))
		}

		// The body was consumed above, so install the (possibly edited)
		// payload with a matching length.
		resp.Body = ioutil.NopCloser(bytes.NewBuffer(payload))
		resp.ContentLength = int64(len(payload))
		resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(payload)), 10))
		return nil
	}

	// errFunc handles transport failures (e.g. when real_server is shut down).
	// 502 signals that the upstream, not this proxy, failed.
	errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
		http.Error(w, "ErrorHandler error:"+err.Error(), http.StatusBadGateway)
	}
	return &httputil.ReverseProxy{
		Director:       director,
		Transport:      transport,
		ModifyResponse: modifyFunc,
		ErrorHandler:   errFunc}
}

// singleJoiningSlash joins two URL path segments with exactly one slash
// between them, e.g. ("/base", "/aa") => "/base/aa".
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}
View Code

实际服务器:

package main

import (
    "fmt"
    "io"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"
)

// main starts the two demo real servers and blocks until SIGINT or SIGTERM.
func main() {
	rs1 := &RealServer{Addr: "127.0.0.1:2003"}
	rs1.Run()
	rs2 := &RealServer{Addr: "127.0.0.1:2004"}
	rs2.Run()

	// Wait for a shutdown signal. signal.Notify never blocks when delivering,
	// so the channel needs a buffer or a signal sent before we start receiving
	// can be lost.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
}

// RealServer is a demo backend ("real server") that the proxies forward to.
type RealServer struct {
	Addr string // listen address, e.g. "127.0.0.1:2003"
}

// Run starts an HTTP server on r.Addr in a background goroutine and returns
// immediately. A listen failure terminates the process via log.Fatal.
func (r *RealServer) Run() {
	log.Println("Starting httpserver at " + r.Addr)
	mux := http.NewServeMux()
	mux.HandleFunc("/", r.HelloHandler)
	mux.HandleFunc("/base/error", r.ErrorHandler)
	server := &http.Server{
		Addr:         r.Addr,
		ReadTimeout:  time.Second * 3,
		WriteTimeout: time.Second * 3,
		Handler:      mux,
	}
	go func() {
		log.Fatal(server.ListenAndServe())
	}()
}

// HelloHandler writes two lines: the URL this server was reached at, then
// the client-address information (RemoteAddr, X-Forwarded-For, X-Real-Ip),
// so the effect of each proxy layer can be observed.
func (r *RealServer) HelloHandler(w http.ResponseWriter, req *http.Request) {
	// e.g. request 127.0.0.1:8008/abc?sdsdsa=11
	//   r.Addr       = 127.0.0.1:8008
	//   req.URL.Path = /abc
	fmt.Println(req.Host)
	upath := fmt.Sprintf("http://%s%s\n", r.Addr, req.URL.Path)
	realIP := fmt.Sprintf("RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n",
		req.RemoteAddr, req.Header.Get("X-Forwarded-For"), req.Header.Get("X-Real-Ip"))

	io.WriteString(w, upath)
	io.WriteString(w, realIP)
}

// ErrorHandler always answers 500 with the body "error handler"; it backs
// the /base/error route used to exercise the proxies' error paths.
func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
	upath := "error handler"
	w.WriteHeader(http.StatusInternalServerError)
	io.WriteString(w, upath)
}
View Code

负载均衡策略:

  • 随机负载
  •   随机挑选目标服务器ip
  • 轮询负载
  •   ABC三台服务器,按ABCABC的顺序依次轮询
  • 加权负载
  •   给目标设置访问权重,按照权重轮询
  • 一致性hash负载
  •   请求固定的url访问固定的ip

随机负载:

package load_balance

import (
    "errors"
    "fmt"
    "math/rand"
    "strings"
)

// RandomBalance picks a backend address uniformly at random on each call.
type RandomBalance struct {
	// curIndex is the index of the most recently returned address.
	curIndex int
	// rss holds the registered backend addresses.
	rss []string
	// conf is the observed configuration source read by Update.
	conf LoadBalanceConf
}

// Add registers params[0] as a backend address; extra params are ignored.
func (r *RandomBalance) Add(params ...string) error {
	if len(params) < 1 {
		return errors.New("param len 1 at least")
	}
	r.rss = append(r.rss, params[0])
	return nil
}

// Next returns a random registered address, or "" if none are registered.
func (r *RandomBalance) Next() string {
	n := len(r.rss)
	if n == 0 {
		return ""
	}
	r.curIndex = rand.Intn(n)
	return r.rss[r.curIndex]
}

// Get ignores key and delegates to Next; the error is always nil.
func (r *RandomBalance) Get(key string) (string, error) {
	return r.Next(), nil
}

// SetConf sets the configuration source that Update reads.
func (r *RandomBalance) SetConf(conf LoadBalanceConf) {
	r.conf = conf
}

// Update rebuilds the address list when the configured source is a
// *LoadBalanceZkConf or *LoadBalanceCheckConf. Each GetConf() entry is split
// on "," and passed to Add, so only its first field is kept.
func (r *RandomBalance) Update() {
	switch conf := r.conf.(type) {
	case *LoadBalanceZkConf:
		fmt.Println("Update get conf:", conf.GetConf())
		r.rss = []string{}
		for _, entry := range conf.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	case *LoadBalanceCheckConf:
		fmt.Println("Update get conf:", conf.GetConf())
		r.rss = nil
		for _, entry := range conf.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	}
}
random.go
package load_balance

import (
    "fmt"
    "testing"
)

// TestRandomBalance checks that Next only ever returns registered addresses.
// The picks are also printed so the random distribution can be eyeballed.
func TestRandomBalance(t *testing.T) {
	rb := &RandomBalance{}
	addrs := []string{
		"127.0.0.1:2003",
		"127.0.0.1:2004",
		"127.0.0.1:2005",
		"127.0.0.1:2006",
		"127.0.0.1:2007",
	}
	registered := make(map[string]bool, len(addrs))
	for _, a := range addrs {
		if err := rb.Add(a); err != nil {
			t.Fatalf("Add(%q): %v", a, err)
		}
		registered[a] = true
	}

	for i := 0; i < 9; i++ {
		got := rb.Next()
		fmt.Println(got)
		if !registered[got] {
			t.Fatalf("Next() = %q, not a registered address", got)
		}
	}
}
random_test
=== RUN   TestRandomBalance
127.0.0.1:2004
127.0.0.1:2005
127.0.0.1:2005
127.0.0.1:2007
127.0.0.1:2004
127.0.0.1:2006
127.0.0.1:2003
127.0.0.1:2003
127.0.0.1:2004
--- PASS: TestRandomBalance (0.00s)
PASS

轮询负载:

package load_balance

import (
    "errors"
    "fmt"
    "strings"
)

// RoundRobinBalance hands out the registered addresses in insertion order and
// wraps around after the last one (A B C A B C ...).
type RoundRobinBalance struct {
	// curIndex is the index of the address the next call will return.
	curIndex int
	// rss holds the registered backend addresses.
	rss []string
	// conf is the observed configuration source read by Update.
	conf LoadBalanceConf
}

// Add registers params[0] as a backend address; extra params are ignored.
func (r *RoundRobinBalance) Add(params ...string) error {
	if len(params) < 1 {
		return errors.New("param len 1 at least")
	}
	r.rss = append(r.rss, params[0])
	return nil
}

// Next returns the next address in rotation, or "" if none are registered.
func (r *RoundRobinBalance) Next() string {
	total := len(r.rss)
	if total == 0 {
		return ""
	}
	// The list may have shrunk since the previous call (e.g. after Update).
	if r.curIndex >= total {
		r.curIndex = 0
	}
	picked := r.rss[r.curIndex]
	r.curIndex++
	if r.curIndex == total {
		r.curIndex = 0
	}
	return picked
}

// Get ignores key and delegates to Next; the error is always nil.
func (r *RoundRobinBalance) Get(key string) (string, error) {
	return r.Next(), nil
}

// SetConf sets the configuration source that Update reads.
func (r *RoundRobinBalance) SetConf(conf LoadBalanceConf) {
	r.conf = conf
}

// Update rebuilds the address list when the configured source is a
// *LoadBalanceZkConf or *LoadBalanceCheckConf. Each GetConf() entry is split
// on "," and passed to Add, so only its first field is kept.
func (r *RoundRobinBalance) Update() {
	switch conf := r.conf.(type) {
	case *LoadBalanceZkConf:
		fmt.Println("Update get conf:", conf.GetConf())
		r.rss = []string{}
		for _, entry := range conf.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	case *LoadBalanceCheckConf:
		fmt.Println("Update get conf:", conf.GetConf())
		r.rss = nil
		for _, entry := range conf.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	}
}
round_robin.go
package load_balance

import (
    "fmt"
    "testing"
)

// Test_main checks that RoundRobinBalance cycles through the registered
// addresses in insertion order and wraps around after the last one.
func Test_main(t *testing.T) {
	rb := &RoundRobinBalance{}
	addrs := []string{
		"127.0.0.1:2003",
		"127.0.0.1:2004",
		"127.0.0.1:2005",
		"127.0.0.1:2006",
		"127.0.0.1:2007",
	}
	for _, a := range addrs {
		if err := rb.Add(a); err != nil {
			t.Fatalf("Add(%q): %v", a, err)
		}
	}

	for i := 0; i < 9; i++ {
		got := rb.Next()
		fmt.Println(got)
		if want := addrs[i%len(addrs)]; got != want {
			t.Fatalf("pick %d: got %q, want %q", i, got, want)
		}
	}
}
round_robin_test
=== RUN   Test_main
127.0.0.1:2003
127.0.0.1:2004
127.0.0.1:2005
127.0.0.1:2006
127.0.0.1:2007
127.0.0.1:2003
127.0.0.1:2004
127.0.0.1:2005
127.0.0.1:2006
--- PASS: Test_main (0.00s)
PASS

加权负载均衡:

  • Weight
  • 初始化时对节点约定的权重
  • currentWeight
  • 节点临时权重,每轮都会变化
  • effectiveWeight
  • 节点有效权重,默认与Weight相同
  • totalWeight
  • 所有节点有效权重之和:sum(effectiveWeight)
// WeightNode is one backend in weighted round-robin balancing.
type WeightNode struct {
    addr            string
    weight          int // weight configured for the node
    currentWeight   int // the node's current (temporary) weight; changes every round
    effectiveWeight int // effective weight; defaults to weight
}
  • 1,currentWeight = currentWeight + effectiveWeight
  • 2,选中最大的currentWeight节点为选中的节点
  • 3,选中节点的 currentWeight = currentWeight-totalWeight(4+3+2=9)

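 计算方法如下(以 A、B、C 三个节点权重分别为 4、3、2 为例。下面的推演假设 currentWeight 初始值等于 weight,而代码中 currentWeight 初始为 0,因此选中顺序可能略有不同,但每 9 次选择中 A、B、C 被选中的次数之比都是 4:3:2):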

第一次:

  •   currentWeight = currentWeight + effectiveWeight
  •     currentWeight     {A=4+4,B=3+3,C=2+2}   ==  {A=8,B=6,C=4}
  •   选中最大的currentWeight节点为选中的节点
  •     A最大 此时作为节点
  •   currentWeight = currentWeight-totalWeight(4+3+2=9)  【选中的节点currentWeight = currentWeight-totalWeight】
  •     currentWeight  {A=8-9,B=6,C=4}  == {A=-1,B=6,C=4}

第二次:{A=-1,B=6,C=4} 开始

  •   currentWeight = currentWeight + effectiveWeight
  •     currentWeight     {A=-1+4,B=6+3,C=4+2}   ==  {A=3,B=9,C=6}
  •   选中最大的currentWeight节点为选中的节点
  •     B最大 此时作为节点
  •   选中的节点currentWeight = currentWeight-totalWeight(4+3+2=9)
  •     currentWeight  {A=3,B=9-9,C=6}  == {A=3,B=0,C=6}

。。。。。。。以此类推。。。。。。。。。

package load_balance

import (
    "errors"
    "fmt"
    "strconv"
    "strings"
)

// WeightRoundRobinBalance implements smooth weighted round robin. Each Next
// call works in three steps:
//   - add every node's effectiveWeight to its currentWeight;
//   - pick the node with the largest currentWeight (the earliest-added node
//     on ties);
//   - subtract the total effective weight from the picked node.
//
// With weights 4, 3, 2 (A, B, C), one cycle of 9 picks is A B C A B A C B A.
type WeightRoundRobinBalance struct {
	// curIndex is not used by the methods below.
	curIndex int
	// rss holds the registered nodes in insertion order.
	rss []*WeightNode
	// rsw is not used by the methods below.
	rsw []int
	// conf is the observed configuration source read by Update.
	conf LoadBalanceConf
}

// WeightNode is one weighted backend.
type WeightNode struct {
	addr            string
	weight          int // configured weight
	currentWeight   int // temporary weight, changes on every Next call
	effectiveWeight int // effective weight, starts equal to weight
}

// Add registers a node from params[0] (address) and params[1] (a
// non-negative integer weight).
func (r *WeightRoundRobinBalance) Add(params ...string) error {
	if len(params) != 2 {
		return errors.New("param len need 2")
	}
	// bitSize 0 parses into the platform's int size, so the int conversion
	// below cannot silently overflow.
	parInt, err := strconv.ParseInt(params[1], 10, 0)
	if err != nil {
		return err
	}
	// A negative weight has no meaning for proportional selection and would
	// distort the total that is subtracted from the picked node.
	if parInt < 0 {
		return errors.New("weight must be >= 0")
	}
	node := &WeightNode{addr: params[0], weight: int(parInt)}
	node.effectiveWeight = node.weight
	r.rss = append(r.rss, node)
	return nil
}

// Next returns the address of the selected node, or "" if no nodes are
// registered.
func (r *WeightRoundRobinBalance) Next() string {
	total := 0
	var best *WeightNode
	for i := 0; i < len(r.rss); i++ {
		w := r.rss[i]
		// step 1: accumulate the sum of all effective weights.
		total += w.effectiveWeight

		// step 2: currentWeight += effectiveWeight.
		w.currentWeight += w.effectiveWeight

		// step 3: effectiveWeight recovers by 1 per call until it is back to
		// weight. (The intent is -1 on communication errors and +1 on
		// success; nothing here lowers it.)
		if w.effectiveWeight < w.weight {
			w.effectiveWeight++
		}
		// step 4: track the node with the largest currentWeight.
		if best == nil || w.currentWeight > best.currentWeight {
			best = w
		}
	}
	if best == nil {
		return ""
	}
	// step 5: the picked node pays back the total effective weight.
	best.currentWeight -= total
	return best.addr
}

// Get ignores key and delegates to Next; the error is always nil.
func (r *WeightRoundRobinBalance) Get(key string) (string, error) {
	return r.Next(), nil
}

// SetConf sets the configuration source that Update reads.
func (r *WeightRoundRobinBalance) SetConf(conf LoadBalanceConf) {
	r.conf = conf
}

// Update rebuilds the node list when the configured source is a
// *LoadBalanceZkConf or *LoadBalanceCheckConf. Each GetConf() entry is an
// "addr,weight" string. Entries rejected by Add are reported and skipped
// rather than silently dropped.
func (r *WeightRoundRobinBalance) Update() {
	switch conf := r.conf.(type) {
	case *LoadBalanceZkConf:
		fmt.Println("WeightRoundRobinBalance get conf:", conf.GetConf())
		r.rss = nil
		for _, ip := range conf.GetConf() {
			r.addFromConf(ip)
		}
	case *LoadBalanceCheckConf:
		fmt.Println("WeightRoundRobinBalance get conf:", conf.GetConf())
		r.rss = nil
		for _, ip := range conf.GetConf() {
			r.addFromConf(ip)
		}
	}
}

// addFromConf registers one "addr,weight" config entry and prints the reason
// if Add rejects it.
func (r *WeightRoundRobinBalance) addFromConf(entry string) {
	if err := r.Add(strings.Split(entry, ",")...); err != nil {
		fmt.Println("WeightRoundRobinBalance skip conf entry", entry, ":", err)
	}
}
weight_round_robin.go
package load_balance

import (
    "fmt"
    "testing"
)

// TestLB checks weighted round robin: over one full cycle (sum of weights = 9
// picks), each node is chosen exactly as many times as its weight.
func TestLB(t *testing.T) {
	rb := &WeightRoundRobinBalance{}
	nodes := []struct {
		addr   string
		weight string
		want   int
	}{
		{"127.0.0.1:2003", "4", 4},
		{"127.0.0.1:2004", "3", 3},
		{"127.0.0.1:2005", "2", 2},
	}
	for _, n := range nodes {
		if err := rb.Add(n.addr, n.weight); err != nil {
			t.Fatalf("Add(%q, %q): %v", n.addr, n.weight, err)
		}
	}

	counts := make(map[string]int)
	for i := 0; i < 9; i++ {
		addr := rb.Next()
		fmt.Println(addr)
		counts[addr]++
	}
	for _, n := range nodes {
		if counts[n.addr] != n.want {
			t.Errorf("%s picked %d times, want %d", n.addr, counts[n.addr], n.want)
		}
	}
}
test

一致性hash(ip_hash、url_hash)

为了解决平衡性问题,引入了虚拟节点的概念(把各节点均匀地覆盖到环上)

package load_balance

import (
    "errors"
    "fmt"
    "hash/crc32"
    "sort"
    "strconv"
    "strings"
    "sync"
)

// Hash maps a byte slice to a 32-bit position on the hash ring.
type Hash func(data []byte) uint32

// UInt32Slice is a slice of ring positions that satisfies sort.Interface
// (ascending order).
type UInt32Slice []uint32

// Len returns the number of positions.
func (s UInt32Slice) Len() int { return len(s) }

// Less orders positions ascending.
func (s UInt32Slice) Less(i, j int) bool { return s[i] < s[j] }

// Swap exchanges the positions at i and j.
func (s UInt32Slice) Swap(i, j int) {
	tmp := s[i]
	s[i] = s[j]
	s[j] = tmp
}

// ConsistentHashBanlance implements consistent hashing with virtual nodes.
// Each real address is placed on a 32-bit hash ring `replicas` times. Get
// maps a key to the first virtual node whose hash is >= the key's hash,
// wrapping to the start of the ring. All ring access is guarded by mux.
type ConsistentHashBanlance struct {
	mux      sync.RWMutex
	hash     Hash
	replicas int               // virtual nodes per real address (replication factor)
	keys     UInt32Slice       // sorted hashes of all virtual nodes (the ring)
	hashMap  map[uint32]string // virtual-node hash -> real address

	// conf is the observed configuration source read by Update.
	conf LoadBalanceConf
}

// NewConsistentHashBanlance creates a balancer with replicas virtual nodes
// per address. fn selects the hash function and defaults to
// crc32.ChecksumIEEE when nil.
func NewConsistentHashBanlance(replicas int, fn Hash) *ConsistentHashBanlance {
	m := &ConsistentHashBanlance{
		replicas: replicas,
		hash:     fn,
		hashMap:  make(map[uint32]string),
	}
	if m.hash == nil {
		// 32-bit hash, so the ring spans 0 .. 2^32-1.
		m.hash = crc32.ChecksumIEEE
	}
	return m
}

// IsEmpty reports whether the ring has no nodes.
func (c *ConsistentHashBanlance) IsEmpty() bool {
	c.mux.RLock()
	defer c.mux.RUnlock()
	return len(c.keys) == 0
}

// Add places params[0] (e.g. an IP address) on the ring as c.replicas
// virtual nodes. Extra params are ignored.
func (c *ConsistentHashBanlance) Add(params ...string) error {
	if len(params) == 0 {
		return errors.New("param len 1 at least")
	}
	c.mux.Lock()
	defer c.mux.Unlock()
	c.addLocked(params[0])
	return nil
}

// addLocked inserts addr's virtual nodes and re-sorts the ring. The caller
// must hold c.mux for writing.
func (c *ConsistentHashBanlance) addLocked(addr string) {
	// Lazily initialise so the zero value is usable too.
	if c.hashMap == nil {
		c.hashMap = make(map[uint32]string)
	}
	if c.hash == nil {
		c.hash = crc32.ChecksumIEEE
	}
	for i := 0; i < c.replicas; i++ {
		h := c.hash([]byte(strconv.Itoa(i) + addr))
		c.keys = append(c.keys, h)
		c.hashMap[h] = addr
	}
	// Keep the ring sorted so Get can binary-search it.
	sort.Sort(c.keys)
}

// Get returns the address owning key: the first virtual node whose hash is
// >= hash(key), wrapping to the first node past the end of the ring. It
// returns an error if the ring is empty.
func (c *ConsistentHashBanlance) Get(key string) (string, error) {
	// Hold the read lock for the whole lookup, including the emptiness check,
	// so a concurrent Add/Update cannot change the ring underneath us.
	c.mux.RLock()
	defer c.mux.RUnlock()
	if len(c.keys) == 0 {
		return "", errors.New("node is empty")
	}
	h := c.hash([]byte(key))

	// First server hash >= data hash is the owning node.
	idx := sort.Search(len(c.keys), func(i int) bool { return c.keys[i] >= h })

	// Past the last node: wrap around to the first one.
	if idx == len(c.keys) {
		idx = 0
	}
	return c.hashMap[c.keys[idx]], nil
}

// SetConf sets the configuration source that Update reads.
func (c *ConsistentHashBanlance) SetConf(conf LoadBalanceConf) {
	c.conf = conf
}

// Update rebuilds the ring when the configured source is a
// *LoadBalanceZkConf or *LoadBalanceCheckConf. Each GetConf() entry is split
// on "," and its first field is added as an address.
//
// The whole rebuild runs under a single write lock and calls addLocked, not
// Add. Calling Add here would try to re-acquire the non-reentrant mutex and
// deadlock.
func (c *ConsistentHashBanlance) Update() {
	switch conf := c.conf.(type) {
	case *LoadBalanceZkConf:
		fmt.Println("Update get conf:", conf.GetConf())
		c.mux.Lock()
		defer c.mux.Unlock()
		c.keys = nil
		c.hashMap = make(map[uint32]string)
		for _, ip := range conf.GetConf() {
			c.addLocked(strings.Split(ip, ",")[0])
		}
	case *LoadBalanceCheckConf:
		fmt.Println("Update get conf:", conf.GetConf())
		c.mux.Lock()
		defer c.mux.Unlock()
		c.keys = nil
		c.hashMap = make(map[uint32]string)
		for _, ip := range conf.GetConf() {
			c.addLocked(strings.Split(ip, ",")[0])
		}
	}
}
consistent_hash.go
package load_balance

import (
    "fmt"
    "testing"
)

// TestNewConsistentHashBanlance checks that consistent hashing is stable. The
// same URL (url hash) or client IP (ip hash) must always map to the same node,
// and that node must be one of the registered addresses.
func TestNewConsistentHashBanlance(t *testing.T) {
	rb := NewConsistentHashBanlance(10, nil)
	registered := map[string]bool{}
	for _, addr := range []string{
		"127.0.0.1:2003",
		"127.0.0.1:2004",
		"127.0.0.1:2005",
		"127.0.0.1:2006",
		"127.0.0.1:2007",
	} {
		if err := rb.Add(addr); err != nil {
			t.Fatalf("Add(%q): %v", addr, err)
		}
		registered[addr] = true
	}

	lookup := func(key string) string {
		t.Helper()
		node, err := rb.Get(key)
		if err != nil {
			t.Fatalf("Get(%q): %v", key, err)
		}
		if !registered[node] {
			t.Fatalf("Get(%q) = %q, not a registered address", key, node)
		}
		fmt.Println(key, "=>", node)
		return node
	}

	for _, key := range []string{
		// url hash
		"http://127.0.0.1:2002/base/getinfo",
		"http://127.0.0.1:2002/base/error",
		"http://127.0.0.1:2002/base/changepwd",
		// ip hash
		"127.0.0.1",
		"192.168.0.1",
	} {
		if first, second := lookup(key), lookup(key); first != second {
			t.Errorf("Get(%q) not stable: %q then %q", key, first, second)
		}
	}
}
test.go

用工厂方法简单封装上述几种负载均衡的调用:

interface.go

package load_balance

// LoadBalance is the common interface implemented by every balancing
// strategy in this package.
type LoadBalance interface {
	// Update rebuilds the balancer from its configuration source; it is
	// reserved for service discovery added later.
	Update()

	// Add registers a backend; the meaning of params is strategy-specific
	// (e.g. address, or address plus weight).
	Add(params ...string) error

	// Get returns the backend chosen for key; some strategies ignore key.
	Get(key string) (string, error)
}

factory.go

package load_balance

// LbType selects a load-balancing strategy.
type LbType int

const (
	// LbRandom picks a random backend.
	LbRandom LbType = iota
	// LbRoundRobin cycles through backends in insertion order.
	LbRoundRobin
	// LbWeightRoundRobin uses smooth weighted round robin.
	LbWeightRoundRobin
	// LbConsistentHash maps keys to backends by consistent hashing.
	LbConsistentHash
)

// LoadBanlanceFactory returns a new, empty balancer of the given type.
// LbRandom and any unknown type yield a random balancer; consistent hashing
// uses 10 virtual nodes per address and the default hash.
func LoadBanlanceFactory(lbType LbType) LoadBalance {
	switch lbType {
	case LbRoundRobin:
		return &RoundRobinBalance{}
	case LbWeightRoundRobin:
		return &WeightRoundRobinBalance{}
	case LbConsistentHash:
		return NewConsistentHashBanlance(10, nil)
	}
	return &RandomBalance{}
}

调用:

// main shows how to build a weighted round-robin balancer through the
// factory and register backends as ("address", "weight") pairs.
func main() {
	rb := load_balance.LoadBanlanceFactory(load_balance.LbWeightRoundRobin)
	backends := [][]string{
		{"http://127.0.0.1:2003/base", "10"},
		{"http://127.0.0.1:2004/base", "20"},
	}
	for _, params := range backends {
		if err := rb.Add(params...); err != nil {
			log.Println(err)
		}
	}
	// ... (rest of the program omitted)
}
原文地址:https://www.cnblogs.com/sunlong88/p/13512362.html