Go语言之sync包 WaitGroup的使用

WaitGroup 是什么以及它能为我们解决什么问题?

WaitGroup在go语言中,用于线程同步,单从字面意思理解,wait等待的意思,group组、团队的意思,WaitGroup就是指等待一组,等待一个系列执行完成后才会继续向下执行。

正常情况下,goroutine的结束过程是不可控制的,我们可以保证的只有main goroutine的终止。

这时候可以借助sync包的WaitGroup来判断goroutine是否完成。

WaitGroup介绍

WatiGroupsync包中的一个struct类型,用来收集需要等待执行完成的goroutine。下面是它的定义:

// WaitGroup用于等待一组线程的结束。
// 父线程调用Add方法来设定应等待的线程的数量。
// 每个被等待的线程在结束时应调用Done方法。同时,主线程里可以调用Wait方法阻塞至所有线程结束。
type WaitGroup struct {
    // 包含隐藏或非导出字段
}

// Add方法向内部计数加上delta,delta可以是负数;
// 如果内部计数器变为0,Wait方法阻塞等待的所有线程都会释放,如果计数器小于0,方法panic。
// 注意Add加上正数的调用应在Wait之前,否则Wait可能只会等待很少的线程。
// 一般来说本方法应在创建新的线程或者其他应等待的事件之前调用。
func (wg *WaitGroup) Add(delta int)

// Done方法减少WaitGroup计数器的值,应在线程的最后执行。
func (wg *WaitGroup) Done()

// Wait方法阻塞直到WaitGroup计数器减为0。
func (wg *WaitGroup) Wait()

它有3个方法:

    Add():每次激活想要被等待完成的goroutine之前,先调用Add(),用来设置或添加要等待完成的goroutine数量

        例如Add(2)或者两次调用Add(1)都会设置等待计数器的值为2,表示要等待2个goroutine完成

    Done():每次需要等待的goroutine在真正完成之前,应该调用该方法来人为表示goroutine完成了,该方法会对等待计数器减1

    Wait():在等待计数器减为0之前,Wait()会一直阻塞当前的goroutine

    也就是说,Add()用来增加要等待的goroutine的数量,Done()用来表示goroutine已经完成了,减少一次计数器,Wait()用来等待所有需要等待的goroutine完成。

示例一

package main

import (
    "fmt"
    "sync"
    "time"
)

// 每个协程都会运行该函数。
// 注意,WaitGroup 必须通过指针传递给函数。
func worker(id int, wg *sync.WaitGroup) {
    fmt.Printf("Worker %d starting
", id)

    // 睡眠一秒钟,以此来模拟耗时的任务。
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done
", id)

    // 通知 WaitGroup ,当前协程的工作已经完成。
    wg.Done()
}

func main() {

    // 这个 WaitGroup 被用于等待该函数开启的所有协程。
    var wg sync.WaitGroup

    // 开启几个协程,并为其递增 WaitGroup 的计数器。
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }

    // 阻塞,直到 WaitGroup 计数器恢复为 0,即所有协程的工作都已经完成。
    wg.Wait()
}

main中开启了5个协程,开启协程之前都先调用了Add()方法增加了一个需要等待goroutine计数。每个goroutine都运行worker()函数,这个函数执行完成后调用Done()方法通知 WaitGroup表示当前协程的完成。

有一点需要注意,worker()函数中使用了指针类型的*sync.WaitGroup作为参数,这里不能使用值类型的sync.WaitGroup作为参数,因为这意味着每个goroutine都拷贝一份wg,每个goroutine都使用自己的wg。这显然是不合理的,这5个协程应该共享一个wg,这样才能知道这几个协程都完成了。实际上,如果使用值类型的参数,main goroutine将会永久阻塞而导致产生死锁。

还有一点需要注意AddDone函数一定要配对,否则可能发生死锁,所报的错误信息如下:

fatal error: all goroutines are asleep - deadlock!

运行:

go run waitgroups.go
Worker 5 starting
Worker 3 starting
Worker 4 starting
Worker 1 starting
Worker 2 starting
Worker 4 done
Worker 1 done
Worker 2 done
Worker 5 done
Worker 3 done

每次运行,各个协程开启和完成的时间可能是不同的。

示例二

在工作中使用时,等待一个协程组全部正确完成则结束;但其中一个协程发生错误,这时候就会阻塞了,不推荐这种用法。

这种场景就需要使用到通知机制,这时候可以使用channel来实现。

package main

import (
	"fmt"
	"sync"
	"time"
)


func main(){
	// 这个 WaitGroup 被用于等待该函数开启的所有协程。
	var wg sync.WaitGroup

	// Add()方法开启了3个等待的协程计数
	wg.Add(3)

        // 开启3个协程,用于工作处理
	go work1(&wg)
	go work2(&wg)
	go work3(&wg)

	// 阻塞,直到 WaitGroup 计数器恢复为 0,即所有协程的工作都已经完成。
	wg.Wait()
}

func work1(wg *sync.WaitGroup){
	fmt.Println("work1 starting")

	// 睡眠一秒钟,以此来模拟耗时的任务。
	time.Sleep(time.Second)
	fmt.Println("work1 done")

	// 通知 WaitGroup ,当前协程的工作已经完成。
	wg.Done()
}

func work2(wg *sync.WaitGroup){
	fmt.Println("work2 starting")

	// 睡眠一秒钟,以此来模拟耗时的任务。
	time.Sleep(time.Second)
	fmt.Println("work2 done")

	// 通知 WaitGroup ,当前协程的工作已经完成。
	wg.Done()
}

func work3(wg *sync.WaitGroup){
	fmt.Println("work3 starting")

	// 睡眠一秒钟,以此来模拟耗时的任务。
	time.Sleep(time.Second)
	fmt.Println("work3 done")

	// 通知 WaitGroup ,当前协程的工作已经完成。
	wg.Done()
}

源码分析

type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}

WaitGroup 结构十分简单,由 nocopystate1 两个字段组成,其中 nocopy 是用来防止复制的

type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

由于嵌入了 nocopy 所以在执行 go vet 时如果检查到 WaitGroup 被复制了就会报错。这样可以一定程度上保证 WaitGroup 不被复制,对了直接 go run 是不会有错误的,所以我们代码 push 之前都会强制要求进行 lint 检查,在 ci/cd 阶段也需要先进行 lint 检查,避免出现这种类似的错误。

~/project/Go-000/Week03/blog/06_waitgroup/02 main*
❯ go run ./main.go

~/project/Go-000/Week03/blog/06_waitgroup/02 main*
❯ go vet .
# github.com/mohuishou/go-training/Week03/blog/06_waitgroup/02
./main.go:7:9: assignment copies lock value to wg2: sync.WaitGroup contains sync.noCopy

state1 的设计非常巧妙,这是一个是十二字节的数据,这里面主要包含两大块,counter 占用了 8 字节用于计数,sema 占用 4 字节用做信号量

为什么要这么搞呢?直接用两个字段一个表示 counter,一个表示 sema 不行么?
不行,我们看看注释里面怎么写的。

// 64-bit value: high 32 bits are counter, low 32 bits are waiter count. > // 64-bit atomic operations require 64-bit alignment, but 32-bit > // compilers do not ensure it. So we allocate 12 bytes and then use > // the aligned 8 bytes in them as state, and the other 4 as storage > // for the sema.

这段话的关键点在于,在做 64 位的原子操作的时候必须要保证 64 位(8 字节)对齐,如果没有对齐的就会有问题,但是 32 位的编译器并不能保证 64 位对齐所以这里用一个 12 字节的 state1 字段来存储这两个状态,然后根据是否 8 字节对齐选择不同的保存方式。

这个操作巧妙在哪里呢?

  • 如果是 64 位的机器那肯定是 8 字节对齐了的,所以是上面第一种方式
  • 如果在 32 位的机器上
    如果恰好 8 字节对齐了,那么也是第一种方式取前面的 8 字节数据
    如果是没有对齐,但是 32 位 4 字节是对齐了的,所以我们只需要后移四个字节,那么就 8 字节对齐了,所以是第二种方式

所以通过 sema 信号量这四个字节的位置不同,保证了 counter 这个字段无论在 32 位还是 64 为机器上都是 8 字节对齐的,后续做 64 位原子操作的时候就没问题了。

这个实现是在 state 方法实现的

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

state 方法返回 counter 和信号量,通过 uintptr(unsafe.Pointer(&wg.state1))%8 == 0 来判断是否 8 字节对齐

Add

func (wg *WaitGroup) Add(delta int) {
    // 先从 state 当中把数据和信号量取出来
	statep, semap := wg.state()

    // 在 waiter 上加上 delta 值
	state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 取出当前的 counter
	v := int32(state >> 32)
    // 取出当前的 waiter,正在等待 goroutine 数量
	w := uint32(state)

    // counter 不能为负数
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}

    // 这里属于防御性编程
    // w != 0 说明现在已经有 goroutine 在等待中,说明已经调用了 Wait() 方法
    // 这时候 delta > 0 && v == int32(delta) 说明在调用了 Wait() 方法之后又想加入新的等待者
    // 这种操作是不允许的
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    // 如果当前没有人在等待就直接返回,并且 counter > 0
	if v > 0 || w == 0 {
		return
	}

    // 这里也是防御 主要避免并发调用 add 和 wait
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}

	// 唤醒所有 waiter,看到这里就回答了上面的问题了
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}

Wait

wait 主要就是等待其他的 goroutine 完事之后唤醒

func (wg *WaitGroup) Wait() {
	// 先从 state 当中把数据和信号量的地址取出来
    statep, semap := wg.state()

	for {
     	// 这里去除 counter 和 waiter 的数据
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)

        // counter = 0 说明没有在等的,直接返回就行
        if v == 0 {
			// Counter is 0, no need to wait.
			return
		}

		// waiter + 1,调用一次就多一个等待者,然后休眠当前 goroutine 等待被唤醒
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}

Done

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

总结

通过WaitGroup提供的三个函数:Add,Done,Wait,可以轻松实现等待某个协程或协程组完成的同步操作。但在使用时要注意:

  • WaitGroup 可以用于一个 goroutine 等待多个 goroutine 干活完成,也可以多个 goroutine 等待一个 goroutine 干活完成,是一个多对多的关系
    多个等待一个的典型案例是 singleflight,这个在后面将微服务可用性的时候还会再讲到,感兴趣可以看看源码
  • Add(n>0) 方法应该在启动 goroutine 之前调用,然后在 goroution 内部调用 Done 方法
  • WaitGroup 必须在 Wait 方法返回之后才能再次使用
  • Done 只是 Add 的简单封装,所以实际上是可以通过一次加一个比较大的值减少调用,或者达到快速唤醒的目的。
  • 协程函数要使用指针类型的*sync.WaitGroup作为参数,不能使用值类型的sync.WaitGroup作为参数
  • Add的数量和Done的调用数量必须相等,否则可能发生死锁

WaitGroup在需要等待多个任务结束再返回的业务来说还是很有用的,但现实中用的更多的可能是,先等待一个协程组,若所有协程组都正确完成,则一直等到所有协程组结束;若其中有一个协程发生错误,则告诉协程组的其他协程,全部停止运行(本次任务失败)以免浪费系统资源。

该场景WaitGroup是无法实现的,那么该场景该如何实现呢,就需要用到通知机制,其实也可以用channel来实现,具体的解决办法,请看后续的文章。

这样说来,WaitGroup的使用场景是有限的。

原文地址:https://www.cnblogs.com/niuben/p/14415196.html