sync.WaitGroup

sync.WaitGroupsync.ErrGroupGo 语言中用于处理并发任务的两个重要工具。它们都可以帮助我们管理多个 goroutine 的执行和同步,但它们的使用场景和功能略有不同。其中 sync.ErrGroup 主要用于处理并发任务中的错误。且其是基于 sync.WaitGroup 实现的,所以先了解 sync.WaitGroup

sync.WaitGroupGo 语言中用于等待一组 goroutine 完成的同步原语。它提供了一个简单的方式来协调多个 goroutine 的执行,确保在主程序退出之前所有的 goroutine 都已经完成。

sync.WaitGroup 的主要方法有:

  • Add(delta int):增加等待的 goroutine 数量。通常在启动新的 goroutine 之前调用。
  • Done():表示一个 goroutine 完成。通常在 goroutine 的最后调用。
  • Wait():阻塞直到所有的 goroutine 都完成。通常在主程序中调用,等待所有的 goroutine 完成。

使用示例

下面是一个使用 sync.WaitGroup 的简单示例:

1
2
3
4
5
6
7
8
9
10
11
var wg sync.WaitGroup

for i := 0; i < 5; i++ {
wg.Add(1) // 增加等待的 goroutine 数量
go func(i int) {
defer wg.Done() // 在 goroutine 完成时调用 Done()
fmt.Println("Hello from goroutine", i)
}(i)
}

wg.Wait() // 等待所有 goroutine 完成

这个示例启动了 5 个 goroutine,每个 goroutine 在完成时调用 wg.Done(),而主程序在调用 wg.Wait() 时会阻塞,直到所有的 goroutine 都完成。也是平时使用 sync.WaitGroup 的常见模式。

需要注意的是:

  • 不需要初始化 sync.WaitGroup,它的零值是有效的。
  • Add 方法可以在任何时候调用,但通常在启动新的 goroutine 之前调用。
  • Done 方法可以在任何时候调用,但通常在 goroutine 的最后调用。
  • Wait 方法会阻塞,直到所有的 goroutine 都完成。

当然这还会有一些问题:

  • 在协程中同时调用了 AddDoneWait 方法,当其并发执行的时候会导致死锁或数据竞争等问题吗?
  • 如果 Done 方法在 Wait 方法之前调用会怎么样?

第二个问题其实可以通过写代码来验证,运行的时候会出现一个 panic 错误。而第一个问题,以及第二个问题的 panic 错误是怎么引起的,就需要看 sync.WaitGroup 的源码了。

sync.WaitGroup 源码阅读

Go 1.24.3 版本的 sync.WaitGroup 源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// A WaitGroup waits for a collection of goroutines to finish.
// The main goroutine calls [WaitGroup.Add] to set the number of
// goroutines to wait for. Then each of the goroutines
// runs and calls [WaitGroup.Done] when finished. At the same time,
// [WaitGroup.Wait] can be used to block until all goroutines have finished.
//
// A WaitGroup must not be copied after first use.
//
// In the terminology of [the Go memory model], a call to [WaitGroup.Done]
// “synchronizes before” the return of any Wait call that it unblocks.
//
// [the Go memory model]: https://go.dev/ref/mem
type WaitGroup struct {
noCopy noCopy

state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}

// Add adds delta, which may be negative, to the [WaitGroup] counter.
// If the counter becomes zero, all goroutines blocked on [WaitGroup.Wait] are released.
// If the counter goes negative, Add panics.
//
// Note that calls with a positive delta that occur when the counter is zero
// must happen before a Wait. Calls with a negative delta, or calls with a
// positive delta that start when the counter is greater than zero, may happen
// at any time.
// Typically this means the calls to Add should execute before the statement
// creating the goroutine or other event to be waited for.
// If a WaitGroup is reused to wait for several independent sets of events,
// new Add calls must happen after all previous Wait calls have returned.
// See the WaitGroup example.
func (wg *WaitGroup) Add(delta int) {
if race.Enabled {
if delta < 0 {
// Synchronize decrements with Wait.
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(&wg.sema))
}
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}

// Done decrements the [WaitGroup] counter by one.
func (wg *WaitGroup) Done() {
wg.Add(-1)
}

// Wait blocks until the [WaitGroup] counter is zero.
func (wg *WaitGroup) Wait() {
if race.Enabled {
race.Disable()
}
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(&wg.sema))
}
runtime_SemacquireWaitGroup(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}

sync.WaitGroup 结构体

1
2
3
4
5
6
type WaitGroup struct {
noCopy noCopy

state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}

这个结构体包含了 3 个字段:

  • noCopy:用于防止 sync.WaitGroup 被复制。进一步去看 noCopy 的定义会看到就是一个空结构体,查了一些资料后得知主要是用于标识 sync.WaitGroup 不能被复制,如果代码中复制了 sync.WaitGroup,那么在静态检查时 vet 工具能够识别并报错。
  • state:一个 atomic.Uint64 原子类型的字段,用于存储 WaitGroup 的状态。根据注释可知其高 32 位表示 WaitGroup 的计数器(counter) 的值,低 32 位表示等待者(waiter)的值,后面分析源码后可知这里的等待者就是等待的 goroutine 数量。
  • sema:一个 uint32 的字段,根据变量名可猜测其用于实现信号量的功能。而基于后面的源码可确定,调用 wg.Wait() 时的阻塞和唤醒都依赖这个信号量,其主要用于阻塞和唤醒 waiter

sync.Add 方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
func (wg *WaitGroup) Add(delta int) {
// ... 省略部分 race 静态检查代码

state := wg.state.Add(uint64(delta) << 32) // 将 delta 左移 32 位后与 state 进行原子加法,也就是 counter + delta
v := int32(state >> 32) // 取出高 32 位的 counter
w := uint32(state) // 取出低 32 位的 waiter

// ... 省略部分 race 静态检查代码

// 判断 counter 是否小于 0,如果小于 0 则 panic
if v < 0 {
panic("sync: negative WaitGroup counter")
}

// 判断 waiter 是否不等于 0 且 delta 大于 0 且 counter 等于 delta,如果是则 panic。
// w != 0 表示已经有 goroutine 在等待了,delta > 0 表示当前是增加计数器的操作,v == int32(delta) 表示当前的计数器值等于 delta。
// 这三者同时成立说明当前的 goroutine 在等待的同时又增加了计数器的值,这种情况会报错。
// 也就是在 WaitGroup 的使用中,不能在 Wait 之后再调用 Add 方法。
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}

// 如果 counter 大于 0 或者 waiter 等于 0,则直接返回。因为此时说明对 state 的修改已经完成了。
if v > 0 || w == 0 {
return
}

// 运行到此处说明 counter(变量 v) == 0 且 waiter(变量 w) > 0。
// 造成这个情况的原因可能是:
// - Add 和 Wait 方法并发调用
// - 在 Add 之前调用了 Wait 方法

// 这里进行了一个简单的健全性检查,以检测 WaitGroup 的误用:判断当前的 state 是否等于之前的 state,如果不等则 panic。
// state 变量是之前函数开始获取增加了 delta 后的值,而 wg.state.Load() 是当前的 state。如果不等则说明在运行到这行代码之前,state 的值已经被其他 goroutine 修改了。
// 这里是为了检测并发调用 Add 和 Wait 方法的情况。
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}

// 重置 state 为 0,此时 counter 的值已经是 0 了,所以这里的操作只是为了将 waiter 的值置为 0。
wg.state.Store(0)

// 释放 waiter 的信号量,也就是唤醒所有等待的 goroutine。
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}

上面代码中删去了部分 race 静态检查的代码,主要是为了简化代码的阅读。race 相关的代码主要是用于检测数据竞争的,具体还没有深入研究,暂且按下不表。

runtime_SemreleaseGo 语言底层 runtime 实现的用于唤醒当前 goroutine 的函数,就不深究了。

至此 Add 方法的实现也就分析完了,总结一下:该方法主要是用于给 WaitGroup 的计数器值加上 delta,并且在计数器为 0 且有等待者的情况下,唤醒所有等待的 goroutine。同时也会对一些错误的使用方式进行检查,比如在 Wait 方法之前调用 Add 方法,或者在 Add 方法中调用 Wait 方法等。

sync.Done 方法

1
2
3
func (wg *WaitGroup) Done() {
wg.Add(-1)
}

Done 方法的实现非常简单,就是调用 Add(-1) 方法来将计数器减 1。也就是在 goroutine 完成时调用 Done() 方法来通知 WaitGroup,表示当前的 goroutine 已经完成了。

而上面分析完 Add 方法后会发现,delta 这个值是支持为负数的,所以 Done 方法也可以直接调用 Add 方法来实现。这里的 Done 方法只是为了语义上的清晰,表示当前的 goroutine 已经完成了。而不是直接调用 Add(-1) 方法。

sync.Wait 方法

分析完 AddDone 方法后,接下来分析最后一个方法: Wait 方法。还是删除了部分 race 静态检查的代码来简化代码的阅读。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
func (wg *WaitGroup) Wait() {
for {
// 获取当前的 state 值,并根据其高 32 位和低 32 位分别获取 counter 和 waiter 的值
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)

// 如果 counter 等于 0,则说明没有 goroutine 在等待,直接返回
if v == 0 {
return
}

// 这里用了一个 CAS(Compare And Swap)操作来增加 waiter 的值
// 也就是将 state 的低 32 位加 1,表示有一个 goroutine 在等待
// 如果 CAS 操作成功,则说明当前的 goroutine 成功地增加了 waiter 的值
if wg.state.CompareAndSwap(state, state+1) {
// 阻塞当前的 goroutine,等待其他 goroutine 完成
runtime_SemacquireWaitGroup(&wg.sema)

// 这里再次检查 WaitGroup 的状态,如果不等于 0,则说明又有其他 goroutine 修改了 state 的值
// 这里的 panic 是为了检测 WaitGroup 的误用:并发调用 Add 和 Wait 方法
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}

// 若上面的判断通过,则说明当前的 goroutine 成功地获取了信号量,表示所有的 goroutine 都已经完成了
return
}
}
}

上面的代码中,用了一个 for 循环,主要是因为 wg.state.CompareAndSwap(state, state+1) 方法是一个原子性的 CAS 操作,所谓的 CAS 操作,就是先 CompareSwap。当调用 wg.state.CompareAndSwap(state, state+1) 时,CompareAndSwap 方法会先判断 wg.state 值是否等于传进来的第一个参数 state,如果相等,则将其替换为第二个参数 state+1 的值,并返回 true;如果 wg.state 值与 state 不相等,则不会修改 wg.state,并返回 false。这样,就保证了对 wg.state 的修改是原子性的。

想象在一个并发场景中,可能有多个 goroutine 同时调用 Wait 方法,这个时候就需要用到 CAS 操作来保证对 wg.state 的修改是安全的。而如果 wg.state 的值已经被其他 goroutine 修改了,那么就需要重新获取 wg.state 的值,然后再进行 CAS 操作。这样就需要一个循环来不断地尝试获取信号量,直到成功为止。

在操作成功后,就调用 runtime_SemacquireWaitGroup(&wg.sema) 方法来阻塞当前的 goroutine,等待其他 goroutine 完成。等待其他 goroutine 完成后,就会调用 wg.Done() 方法来通知 WaitGroup,表示当前的 goroutine 已经完成了。这样就运行上之前看的 Add 方法中的 runtime_Semrelease(&wg.sema, false, 0) 方法来唤醒所有等待的 goroutine

此时 Wait 方法就会继续执行到判断 wg.state.Load() != 0 的地方,如果不等于 0,则说明有其他 goroutine 修改了 state 的值,这里就会 panic。因为前面看过 Add 方法的实现,在调用 runtime_Semrelease(&wg.sema, false, 0) 方法前已经将 state 的值置为 0 了,所以这里如果 state 的值不等于 0,说明又有并发调用 WaitAdd 方法的情况发生了。

而如果 wg.state.Load() == 0,则说明所有的 goroutine 都已经完成了,当前的 goroutine 也就可以返回了。

总结来说就是管理 state 中的 waiter 的值,来实现对 goroutine 的阻塞并等待被 Done 方法唤醒。

总结

分析完 sync.WaitGroup 的源码后,基本上就能理解 sync.WaitGroup 的实现原理了。总结一下:

  • sync.WaitGroup 是一个用于等待一组 goroutine 完成的同步原语。
  • AddDone 方法用于修改计数器的值,维护 sync.WaitGroup 内部的计数器 counter 和等待者 waiter 的值。
  • 注意在使用 sync.WaitGroup 时,不能并发调用 AddWait 方法,否则会导致 panic

sync.ErrGroup

sync.ErrGroup 是基于 sync.WaitGroup 实现的,但是它和 sync.WaitGroup 的主要区别在于:

  • 限制并发数sync.ErrGroup 可以限制并发的数量,而 sync.WaitGroup 不支持。
  • 错误处理sync.ErrGroup 可以在其中一个 goroutine 出现错误时,立即停止所有的 goroutine 的执行,而 sync.WaitGroup 只能等待所有的 goroutine 完成。
  • 返回错误sync.ErrGroup 可以返回第一个出现的错误,而 sync.WaitGroup 只能等待所有的 goroutine 完成,无法返回错误。

使用示例

基本使用

下面是一个使用 sync.ErrGroup 的简单示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
func main() {
// 创建一个 errgroup.Group 实例
var g errgroup.Group

// 启动 5 个 goroutine
for i := 0; i < 5; i++ {
g.Go(func() error {
time.Sleep(time.Duration(i) * time.Second)
if i == 3 {
return fmt.Errorf("error in goroutine %d", i)
}
fmt.Println("Hello from goroutine", i)
return nil
})
}

// 等待所有 goroutine 完成
if err := g.Wait(); err != nil {
fmt.Println("Error:", err)
}
}

几本用法和 sync.WaitGroup 一样,初始化 errgroup.Group 后,启动协程,然后调用 g.Wait() 方法等待所有的 goroutine 完成。一样的零值可用,不必显式初始化。

不同的是这里不再通过 AddDone 方法来管理 goroutine 的数量,而是通过 g.Go() 方法来启动 goroutine。在 g.Go() 方法中可以返回一个错误,如果有任何一个 goroutine 返回了错误,g.Wait() 方法会立即返回这个错误,并且会停止其他 goroutine 的运行。

限制并发数

sync.ErrGroup 内部已经实现了限制并发数的功能,所以可以直接使用 g.Go() 方法来启动 goroutine,并且用 SetLimit 方法来限制并发数。下面是一个使用 sync.ErrGroup 限制并发数的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
func main() {
// 创建一个 errgroup.Group 实例
var g errgroup.Group

// 设置并发数为 2
g.SetLimit(2)

// 启动 5 个 goroutine
// 这里的 goroutine 会被限制为 2 个并发执行
// 也就是每次只会有 2 个 goroutine 在执行
// 其他的 goroutine 会被阻塞,直到有 goroutine 完成
for i := 0; i < 5; i++ {
g.Go(func() error {
time.Sleep(time.Duration(i) * time.Second)
if i == 3 {
return fmt.Errorf("error in goroutine %d", i)
}
fmt.Println("Hello from goroutine", i)
return nil
})
}

// 等待所有 goroutine 完成
if err := g.Wait(); err != nil {
fmt.Println("Error:", err)
}
}

上面的代码中,调用 g.SetLimit(2) 方法来限制并发数为 2,也就是每次只会有 2goroutine 在执行,其他的 goroutine 会被阻塞,直到有 goroutine 完成。
这样就可以控制并发的数量,避免过多的 goroutine 同时执行导致的资源竞争和性能问题。

Context 取消

sync.ErrGroup 还可以与 context.Context 一起使用,以便在需要时取消所有的 goroutine。下面是一个使用 sync.ErrGroupcontext.Context 的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func main() {
g, ctx := errgroup.WithContext(context.Background())

// 启动 5 个 goroutine
for i := 0; i < 5; i++ {
g.Go(func() error {
select {
case <-ctx.Done():
return ctx.Err()
default:
time.Sleep(time.Duration(i) * time.Second)
if i == 3 {
return fmt.Errorf("error in goroutine %d", i)
}
fmt.Println("Hello from goroutine", i)
return nil
}
})
}

// 等待所有 goroutine 完成
if err := g.Wait(); err != nil {
fmt.Println("Error:", err)
}
}

上面的代码中,创建了一个 context.Context 实例,并在 g.Go() 方法中使用 select 语句来监听 ctx.Done() 的信号。如果接收到取消信号,则返回 ctx.Err() 错误。这样就可以在需要时取消所有的 goroutine

尝试启动

sync.ErrGroup 还可以尝试启动 goroutine,也就是执行 g.TryGo() 方法来调用函数,如果函数运行成功,则返回 true,否则返回 false。注意这里的尝试运行,并非是说函数本身是否报错,而是说当前 group 运行的 goroutine 是否已经到达了最大限制。所以如果要使用 TryGo 方法来启动 goroutine,需要先设置最大限制。下面是一个使用 sync.ErrGroup 尝试启动 goroutine 的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func main() {
// 创建一个 errgroup.Group 实例
var g errgroup.Group

// 设置并发数为 2
g.SetLimit(2)

// 启动 5 个 goroutine
for i := 0; i < 5; i++ {
if !g.TryGo(func() error {
time.Sleep(time.Duration(i) * time.Second)
fmt.Println("Hello from goroutine", i)
return nil
}) {
fmt.Println("Failed to start goroutine", i)
} else {
fmt.Println("Successfully started goroutine", i)
}
}

// 等待所有 goroutine 完成
if err := g.Wait(); err != nil {
fmt.Println("Error:", err)
}
}

注意两点:

  • 这里如果没有用 SetLimit 方法设置最大限制,则所有的 TryGo 方法都会返回 true,也就是没有限制。
  • TryGo 方法的返回值是 bool 类型,如果返回 false,则说明当前的 goroutine 已经达到最大限制,无法再启动新的 goroutine。但是如果使用 g.Go() 方法,则会阻塞,直到有 goroutine 完成后再启动新的 goroutine

sync.ErrGroup 源码阅读

sync.ErrGroup 的源码在非常简单的 sync 包中,源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
//
// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks
// returning errors.
package errgroup

import (
"context"
"fmt"
"runtime"
"runtime/debug"
"sync"
)

type token struct{}

// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task. A Group should not be reused for different tasks.
//
// A zero Group is valid, has no limit on the number of active goroutines,
// and does not cancel on error.
type Group struct {
cancel func(error)

wg sync.WaitGroup

sem chan token

errOnce sync.Once
err error

mu sync.Mutex
panicValue any // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
abnormal bool // some Group.Go goroutine terminated abnormally (panic or goexit).
}

func (g *Group) done() {
if g.sem != nil {
<-g.sem
}
g.wg.Done()
}

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancelCause(ctx)
return &Group{cancel: cancel}, ctx
}

// Wait blocks until all function calls from the Go method have returned
// normally, then returns the first non-nil error (if any) from them.
//
// If any of the calls panics, Wait panics with a [PanicValue];
// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel(g.err)
}
if g.panicValue != nil {
panic(g.panicValue)
}
if g.abnormal {
runtime.Goexit()
}
return g.err
}

// Go calls the given function in a new goroutine.
// The first call to Go must happen before a Wait.
// It blocks until the new goroutine can be added without the number of
// active goroutines in the group exceeding the configured limit.
//
// It blocks until the new goroutine can be added without the number of
// goroutines in the group exceeding the configured limit.
//
// The first goroutine in the group that returns a non-nil error, panics, or
// invokes [runtime.Goexit] will cancel the associated Context, if any.
func (g *Group) Go(f func() error) {
if g.sem != nil {
g.sem <- token{}
}

g.add(f)
}

func (g *Group) add(f func() error) {
g.wg.Add(1)
go func() {
defer g.done()
normalReturn := false
defer func() {
if normalReturn {
return
}
v := recover()
g.mu.Lock()
defer g.mu.Unlock()
if !g.abnormal {
if g.cancel != nil {
g.cancel(g.err)
}
g.abnormal = true
}
if v != nil && g.panicValue == nil {
switch v := v.(type) {
case error:
g.panicValue = PanicError{
Recovered: v,
Stack: debug.Stack(),
}
default:
g.panicValue = PanicValue{
Recovered: v,
Stack: debug.Stack(),
}
}
}
}()

err := f()
normalReturn = true
if err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
}

// TryGo calls the given function in a new goroutine only if the number of
// active goroutines in the group is currently below the configured limit.
//
// The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool {
if g.sem != nil {
select {
case g.sem <- token{}:
// Note: this allows barging iff channels in general allow barging.
default:
return false
}
}

g.add(f)
return true
}

// SetLimit limits the number of active goroutines in this group to at most n.
// A negative value indicates no limit.
// A limit of zero will prevent any new goroutines from being added.
//
// Any subsequent call to the Go method will block until it can add an active
// goroutine without exceeding the configured limit.
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
if n < 0 {
g.sem = nil
return
}
if len(g.sem) != 0 {
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
}
g.sem = make(chan token, n)
}

// PanicError wraps an error recovered from an unhandled panic
// when calling a function passed to Go or TryGo.
type PanicError struct {
Recovered error
Stack []byte // result of call to [debug.Stack]
}

func (p PanicError) Error() string {
// A Go Error method conventionally does not include a stack dump, so omit it
// here. (Callers who care can extract it from the Stack field.)
return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
}

func (p PanicError) Unwrap() error { return p.Recovered }

// PanicValue wraps a value that does not implement the error interface,
// recovered from an unhandled panic when calling a function passed to Go or
// TryGo.
type PanicValue struct {
Recovered any
Stack []byte // result of call to [debug.Stack]
}

func (p PanicValue) String() string {
if len(p.Stack) > 0 {
return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", p.Recovered, p.Stack)
}
return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
}

注释中写得很清楚,errgroup 包提供了同步、错误传递和上下文取消的功能,其主要用于实现一批 goroutine 的协作。errgroup.Groupsync.WaitGroup 是相关的,但添加了处理任务返回的错误的功能。

errgroup.Group 结构体

初始化声明的 errgroup.Group 结构体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
type Group struct {
cancel func(error)

wg sync.WaitGroup

sem chan token

errOnce sync.Once
err error

mu sync.Mutex
panicValue any // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
abnormal bool // some Group.Go goroutine terminated abnormally (panic or goexit).
}

这个包所有对外暴露的方法都关联在这个结构体上,下面分析一下这个结构体的字段:

  • cancel:一个函数,用于取消 context.Context,如果有错误发生,则调用这个函数来取消 context.Context
  • wg:一个 sync.WaitGroup,用于等待所有的 goroutine 完成。
  • sem:一个信号量,用于限制并发的数量。这里的 token 是一个空结构体,空结构体不占用内存,所以常被用来基于 chan 实现信号量的传递。
  • errOnce:一个 sync.Once,用于确保只调用一次错误处理函数。
  • err:一个错误,用于存储第一个出现的错误。
  • mu:一个互斥锁,用于保护 panicValueabnormal 字段的并发访问。
  • panicValue:一个 any 类型的字段,用于存储 goroutine 中发生的 panic 错误。
  • abnormal:一个布尔值,表示是否有 goroutine 异常终止(panicgoexit)。

仅看结构体的定义,基本上就能理解这个结构体的作用了。主要是用于管理 goroutine 的并发数量、错误处理和异常终止的情况。具体的实现逻辑还需要结合方法来分析。

errgroup.WithContext 方法

1
2
3
4
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancelCause(ctx)
return &Group{cancel: cancel}, ctx
}

WithContext 方法用于创建一个新的 errgroup.Group 实例,并返回一个与之关联的 context.Context。这个方法的主要作用是将 context.Contexterrgroup.Group 结合起来,以便在需要时取消所有的 goroutine

context.WithCancelCause 方法用于创建一个可以取消的上下文,并返回一个取消函数 cancel。这个取消函数会被添加到结构体中,后面的 Wait 方法中会调用这个函数来取消上下文。返回的 ctx 是一个新的上下文,包含了取消函数和原始上下文的值。

errgroup.SetLimit 方法

在使用中,errgroup.Group 的并发数量是通过 errgroup.SetLimit 方法设定的,下面是 SetLimit 方法的实现:

1
2
3
4
5
6
7
8
9
10
func (g *Group) SetLimit(n int) {
if n < 0 {
g.sem = nil
return
}
if len(g.sem) != 0 {
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
}
g.sem = make(chan token, n)
}

其实本质上就是创建 ntoken (空结构体)的信号量,n 为负数时表示不限制并发数量。其实感觉这里还需要检查一下 n 的值是否为 0

另外这里在修改 sem 的值前,检查了当前的 sem 是否为空,如果不为空则会 panic。这个检查是为了防止在有 goroutine 正在运行时修改并发数量的限制。所以在使用中尤其要注意,在调用 errgroup.Goerrgroup.TryGo 方法之前,才能调用 SetLimit 方法来设置并发数量的限制。否则会导致 panic

errgroup.Go 方法

Go 方法用于启动一个新的 goroutine,并将其添加到 errgroup.Group 中。下面是 Go 方法的实现:

1
2
3
4
5
6
7
func (g *Group) Go(f func() error) {
if g.sem != nil {
g.sem <- token{}
}

g.add(f)
}

这里首先检查了 g.sem 是否为 nil,如果不为 nil,则表示有设置并发数量的限制,这里就会向 sem 中发送一个 token。这个操作在 sem 已经被写满后阻塞,直到有 goroutine 完成后(可推测:完成一个协程就从 g.sem 中读取一个信号量)才能继续执行。

代码中调用了一个内置的 add 方法来添加 goroutine,下面是 add 方法的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
func (g *Group) add(f func() error) {
g.wg.Add(1)
go func() {
defer g.done()
normalReturn := false
defer func() {
if normalReturn {
return
}
v := recover()
g.mu.Lock()
defer g.mu.Unlock()
if !g.abnormal {
if g.cancel != nil {
g.cancel(g.err)
}
g.abnormal = true
}
if v != nil && g.panicValue == nil {
switch v := v.(type) {
case error:
g.panicValue = PanicError{
Recovered: v,
Stack: debug.Stack(),
}
default:
g.panicValue = PanicValue{
Recovered: v,
Stack: debug.Stack(),
}
}
}
}()

err := f()
normalReturn = true
if err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
}

函数首先调用 g.wg.Add(1) 来将 WaitGroup 的计数器加 1,表示有一个新的 goroutine 被添加进来。然后启动一个新的 goroutine 来执行传入的函数 f

这个 goroutine 中使用了两个 defer 语句来区别正常返回和异常返回的情况。仔细看的话和 singleflight.doCall 方法中的实现有点类似,又有些区别。先来看这个 errgroup.add 方法的实现:

  • 首先运行函数 err := f() 来执行传入的函数 f,如果函数返回了错误,则调用 g.errOnce.Do 方法来设置错误值,并通过 g.cancel(g.err) 来取消上下文。这里的 g.errOnce 是一个 sync.Once,表示只会执行一次。也就是如果有多个 goroutine 返回了错误,则只会返回第一个错误。
  • 无论是否返回错误,只要函数执行完成(没有 panicgoexit),都会将 normalReturn 设置为 true,表示正常返回。
  • 无论函数正常运行还是发生错误,此时首先都会进入到第二个 defer 语句中。
    • 检查 normalReturn 是否为 true,如果是,则说明函数运行没有发生 panicgoexit,直接返回。
    • 否则通过 recover() 方法来捕获 panic 的值。然后加锁,防止并发访问。这里的 g.mu.Lock()g.mu.Unlock() 全局只有该处使用,这样可以防止多个 goroutine 同时出错导致的错误后同时读写结构体内的字段。其中修改了结构体中 abnormalpanicValue 字段的值,abnormal 字段后面会在 Wait 方法中使用到,而 panicValue 字段则是用于存储 panic 的值。这里会进行一次 panicValue 字段的判空,来实现只会存储第一个报错的 panic 值。即便错误并发出现后,其他没有拿到锁的 goroutine 在这个值已经被设置后就不会再设置了。
  • 最后进入第一个 defer 语句中,执行的是 g.done() 方法,主要就是释放一个信号量并将 WaitGroup 的计数器减 1。这个方法的实现如下:
1
2
3
4
5
6
func (g *Group) done() {
if g.sem != nil {
<-g.sem
}
g.wg.Done()
}

done 方法中首先检查 g.sem 是否为 nil,如果不为 nil,则从 sem 中读取一个 token,表示有一个 goroutine 完成了。然后调用 g.wg.Done() 来将 WaitGroup 的计数器减 1

至此 Go 方法的实现也就分析完了,主要就是通过 g.wg.Add(1) 来将计数器加 1,然后启动一个新的 goroutine 来执行传入的函数 f。在函数执行完成后,通过 g.done() 方法来将计数器减 1,并释放一个信号量。
如果函数执行过程中发生了 panic,则通过 recover() 方法来捕获 panic 的值,并将其存储到 panicValue 字段中。
如果函数执行过程中返回了错误,则通过 g.errOnce.Do 方法来设置错误值,并取消上下文。

errgroup.Wait 方法

上面的 Go 方法中在运行的过程中会通过 g.wg.Add(1) 来将计数器加 1,而在 Wait 方法中则是通过 g.wg.Wait() 来等待所有的 goroutine 完成。下面是 Wait 方法的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel(g.err)
}
if g.panicValue != nil {
panic(g.panicValue)
}
if g.abnormal {
runtime.Goexit()
}
return g.err
}

Wait 方法用于等待所有的 goroutine 完成,并返回第一个出现的错误(如果 panicValue 存在的话)。

这个方法通过调用 g.wg.Wait() 来等待所有的 goroutine 完成,这里的 g.wg 就是 sync.WaitGroup。作为一个零值可用的数据类型,只需要声明 errgroup.Group 即可使用。

等到所有的 goroutine 完成后,调用 g.cancel(g.err) 来取消上下文,并传入第一个出现的错误(如果有的话)。如果 panicValue 不为 nil,则调用 panic(g.panicValue) 来抛出之前用 recover 捕获的异常。

而如果 abnormaltrue,则调用 runtime.Goexit() 来退出当前的 goroutine

这里又看到了 abnormal 这个字段,之前在Go 并发控制:singleflight 详解中也看到过类似的实现,用于区分是函数运行出现 panic 还是由于协程退出导致的 runtime.Goexit()。这里的实现方法不同,但功能类似,首先 Wait 函数在已经结束阻塞后会有三种情况:

  • 有协程在运行过程中发生了 panic,那么在 errgroup.add 方法中就会将 g.panicValue 设置为捕获的异常,然后在 Wait 方法中再次被 panic
  • 如果 g.panicValuenil,则判断 g.abnormal 是否为 true(如果函数运行被中断则在 add 方法中会被修改),如果是,则说明协程运行被中断不是由于 panic 而是程序的退出,那么就调用 runtime.Goexit() 来退出当前的 goroutine
  • 如果 g.panicValueg.abnormal 都为 nil,则说明所有的 goroutine 都正常运行完成了,那么就返回第一个出现的错误(如果有的话)。

errgroup.TryGo 方法

上面已经的源码已经走完了整个使用流程,下面再分析一下使用 TryGo 方法而非 Go 方法的底层实现。TryGo 方法的实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
func (g *Group) TryGo(f func() error) bool {
if g.sem != nil {
select {
case g.sem <- token{}:
// Note: this allows barging iff channels in general allow barging.
default:
return false
}
}

g.add(f)
return true
}

TryGo 方法和 Go 方法的实现类似,主要区别在于 TryGo 方法不会阻塞,而是尝试向 sem 中发送一个 token。如果发送成功,则表示可以启动一个新的 goroutine,否则返回 false
这里使用了 select 语句来尝试向 sem 中发送一个 token,如果发送成功,则表示可以启动一个新的 goroutine,否则返回 false。如果 semnil,则表示没有限制并发数量,可以直接调用 g.add(f) 方法来添加 goroutine

如果 sem 不为 nil,则使用 select 语句来尝试向 sem 中发送一个 token,如果发送成功,则表示可以启动一个新的 goroutine,否则返回 false。这里的 default 分支是为了避免阻塞,如果没有可用的信号量,则直接返回 false

这样就可以实现尝试启动 goroutine 的功能,而不需要阻塞等待信号量的释放。

如果发送成功,则调用 g.add(f) 方法来添加 goroutine,并返回 true

总结

sync.WaitGroupsync.ErrGroupGo 语言中用于处理并发的两个重要的同步原语。它们都可以用于等待一组 goroutine 完成,但在功能上有一些区别。sync.errgroup 是在 sync.WaitGroup 基础上,增加了错误传递和上下文取消的功能。