fix: some issue

This commit is contained in:
meilin.huang
2025-05-23 17:26:12 +08:00
parent 778cb7f4de
commit 666b191b6c
25 changed files with 758 additions and 98 deletions

View File

@@ -20,6 +20,12 @@ type cacheEntry[T Conn] struct {
lastActive time.Time
}
func (e *cacheEntry[T]) Close() {
if err := e.conn.Close(); err != nil {
logx.Errorf("cache pool - closing connection error: %v", err)
}
}
type CachePool[T Conn] struct {
factory func() (T, error)
mu sync.RWMutex
@@ -48,26 +54,38 @@ func NewCachePool[T Conn](factory func() (T, error), opts ...Option) *CachePool[
// Get 获取连接(自动创建或复用缓存连接)
func (p *CachePool[T]) Get(ctx context.Context) (T, error) {
var zero T
// 先尝试加读锁,仅用于查找可用连接
p.mu.RLock()
for _, entry := range p.cache {
if time.Since(entry.lastActive) <= p.config.IdleTimeout {
p.mu.RUnlock() // 找到后释放读锁
return entry.conn, nil
}
}
p.mu.RUnlock()
// 没有找到可用连接,升级为写锁进行清理和创建
p.mu.Lock()
defer p.mu.Unlock()
var zero T
if p.closed {
return zero, ErrPoolClosed
}
// 1. 尝试从缓存中获取可用连接
// 再次检查是否已有可用连接(防止并发创建)
for key, entry := range p.cache {
if time.Since(entry.lastActive) <= p.config.IdleTimeout {
entry.lastActive = time.Now() // 更新活跃时间
entry.lastActive = time.Now()
return entry.conn, nil
}
// 自动清理闲置超时连接
entry.conn.Close()
// 清理超时连接
entry.Close()
delete(p.cache, key)
}
// 2. 创建新连接并缓存
// 创建新连接
conn, err := p.factory()
if err != nil {
return zero, err
@@ -134,9 +152,7 @@ func (p *CachePool[T]) Close() {
close(p.closeCh)
for _, entry := range p.cache {
if err := entry.conn.Close(); err != nil {
logx.Errorf("cache pool - error closing connection: %v", err)
}
entry.Close()
}
// 触发关闭回调
@@ -197,7 +213,7 @@ func (p *CachePool[T]) cleanupIdle() {
cutoff := time.Now().Add(-p.config.IdleTimeout)
for key, entry := range p.cache {
if entry.lastActive.Before(cutoff) {
entry.conn.Close()
entry.Close()
delete(p.cache, key)
}
}

View File

@@ -17,21 +17,21 @@ var ChanPoolDefaultConfig = PoolConfig{
HealthCheckInterval: 10 * time.Minute,
}
// ConnWrapper 封装连接及其元数据
type ConnWrapper[T Conn] struct {
// chanConn 封装连接及其元数据
type chanConn[T Conn] struct {
conn T
lastActive time.Time // 最后活跃时间
isValid bool // 连接是否有效
}
func (w *ConnWrapper[T]) Ping() error {
func (w *chanConn[T]) Ping() error {
if !w.isValid {
return errors.New("connection marked invalid")
}
return w.conn.Ping()
}
func (w *ConnWrapper[T]) Close() error {
func (w *chanConn[T]) Close() error {
w.isValid = false
return w.conn.Close()
}
@@ -40,7 +40,7 @@ func (w *ConnWrapper[T]) Close() error {
type ChanPool[T Conn] struct {
mu sync.RWMutex
factory func() (T, error)
idleConns chan *ConnWrapper[T]
idleConns chan *chanConn[T]
config PoolConfig
currentConns int32
stats PoolStats
@@ -66,7 +66,7 @@ func NewChannelPool[T Conn](factory func() (T, error), opts ...Option) *ChanPool
// 2. 创建连接池
p := &ChanPool[T]{
factory: factory,
idleConns: make(chan *ConnWrapper[T], config.MaxConns),
idleConns: make(chan *chanConn[T], config.MaxConns),
config: config,
closeChan: make(chan struct{}),
}
@@ -111,6 +111,15 @@ func (p *ChanPool[T]) Get(ctx context.Context) (T, error) {
}
func (p *ChanPool[T]) get() (T, error) {
// 检查连接池是否已关闭
p.mu.RLock()
if p.closed {
p.mu.RUnlock()
var zero T
return zero, ErrPoolClosed
}
p.mu.RUnlock()
// 优先从 channel 获取空闲连接(无锁)
select {
case wrapper := <-p.idleConns:
@@ -185,9 +194,17 @@ func (p *ChanPool[T]) Put(conn T) error {
return nil
}
// 检查连接池是否已关闭
p.mu.RLock()
if p.closed {
p.mu.RUnlock()
return conn.Close()
}
p.mu.RUnlock()
// 快速路径
select {
case p.idleConns <- &ConnWrapper[T]{conn: conn, lastActive: time.Now(), isValid: true}:
case p.idleConns <- &chanConn[T]{conn: conn, lastActive: time.Now(), isValid: true}:
atomic.AddInt32(&p.stats.IdleConns, 1)
atomic.AddInt32(&p.stats.ActiveConns, -1)
return nil
@@ -198,6 +215,11 @@ func (p *ChanPool[T]) Put(conn T) error {
p.mu.Lock()
defer p.mu.Unlock()
// 再次检查是否已关闭
if p.closed {
return conn.Close()
}
// 检查是否超过最大连接数
if atomic.LoadInt32(&p.currentConns) > int32(p.config.MaxConns) {
conn.Close()
@@ -205,7 +227,7 @@ func (p *ChanPool[T]) Put(conn T) error {
} else {
// 直接放入空闲队列
select {
case p.idleConns <- &ConnWrapper[T]{conn: conn, lastActive: time.Now(), isValid: true}:
case p.idleConns <- &chanConn[T]{conn: conn, lastActive: time.Now(), isValid: true}:
default:
conn.Close()
atomic.AddInt32(&p.currentConns, -1)
@@ -228,7 +250,7 @@ func (p *ChanPool[T]) Close() {
close(p.closeChan)
// 2. 临时转移空闲连接
idle := make([]*ConnWrapper[T], 0, len(p.idleConns))
idle := make([]*chanConn[T], 0, len(p.idleConns))
for len(p.idleConns) > 0 {
idle = append(idle, <-p.idleConns)
}
@@ -269,7 +291,7 @@ func (p *ChanPool[T]) checkIdleConns() {
return
}
idle := make([]*ConnWrapper[T], 0, len(p.idleConns))
idle := make([]*chanConn[T], 0, len(p.idleConns))
for len(p.idleConns) > 0 {
idle = append(idle, <-p.idleConns)
}
@@ -303,7 +325,7 @@ func (p *ChanPool[T]) Resize(newMaxConns int) {
closed := 0
// 非阻塞取出待关闭的连接
var wrappers []*ConnWrapper[T]
var wrappers []*chanConn[T]
for len(p.idleConns) > 0 && closed < toClose {
wrappers = append(wrappers, <-p.idleConns)
closed++
@@ -318,7 +340,7 @@ func (p *ChanPool[T]) Resize(newMaxConns int) {
}
// 重建空闲连接通道(无需迁移连接,因 channel 本身无状态)
p.idleConns = make(chan *ConnWrapper[T], newMaxConns)
p.idleConns = make(chan *chanConn[T], newMaxConns)
}
func (p *ChanPool[T]) CheckLeaks() []T {
@@ -329,7 +351,7 @@ func (p *ChanPool[T]) CheckLeaks() []T {
now := time.Now()
// 检查所有空闲连接
idle := make([]*ConnWrapper[T], 0, len(p.idleConns))
idle := make([]*chanConn[T], 0, len(p.idleConns))
for len(p.idleConns) > 0 {
idle = append(idle, <-p.idleConns)
}

View File

@@ -2,11 +2,13 @@ package pool
import (
"mayfly-go/pkg/logx"
"sync"
"golang.org/x/sync/singleflight"
)
type PoolGroup[T Conn] struct {
mu sync.RWMutex
poolGroup map[string]Pool[T]
createGroup singleflight.Group
}
@@ -23,14 +25,33 @@ func (pg *PoolGroup[T]) GetOrCreate(
poolFactory func() Pool[T],
opts ...Option,
) (Pool[T], error) {
// 先尝试读锁获取
pg.mu.RLock()
if p, ok := pg.poolGroup[key]; ok {
pg.mu.RUnlock()
return p, nil
}
pg.mu.RUnlock()
// 使用 singleflight 确保并发安全
v, err, _ := pg.createGroup.Do(key, func() (any, error) {
// 再次检查,避免在等待期间其他 goroutine 已创建
pg.mu.RLock()
if p, ok := pg.poolGroup[key]; ok {
pg.mu.RUnlock()
return p, nil
}
pg.mu.RUnlock()
// 创建新池
logx.Infof("pool group - create pool, key: %s", key)
p := poolFactory()
// 写入时加写锁
pg.mu.Lock()
pg.poolGroup[key] = p
pg.mu.Unlock()
return p, nil
})
@@ -56,6 +77,9 @@ func (pg *PoolGroup[T]) GetCachePool(key string, factory func() (T, error), opts
}
func (pg *PoolGroup[T]) Close(key string) error {
pg.mu.Lock()
defer pg.mu.Unlock()
if p, ok := pg.poolGroup[key]; ok {
logx.Infof("pool group - close pool, key: %s", key)
p.Close()
@@ -66,11 +90,24 @@ func (pg *PoolGroup[T]) Close(key string) error {
}
func (pg *PoolGroup[T]) CloseAll() {
pg.mu.Lock()
defer pg.mu.Unlock()
for key := range pg.poolGroup {
pg.Close(key)
pg.poolGroup[key].Close()
pg.createGroup.Forget(key)
}
pg.poolGroup = make(map[string]Pool[T])
}
func (pg *PoolGroup[T]) AllPool() map[string]Pool[T] {
return pg.poolGroup
pg.mu.RLock()
defer pg.mu.RUnlock()
// 返回 map 的副本,避免外部修改
pools := make(map[string]Pool[T], len(pg.poolGroup))
for k, v := range pg.poolGroup {
pools[k] = v
}
return pools
}

View File

@@ -0,0 +1,605 @@
package pool
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
/**
这个单元测试是ai写的
提示词如下:
我写了一个连接池和缓存池 请帮忙帮我写一下单元测试,以完成对池的功能检查,需要做到以下目标:
1、连接池内的连接不共用用完连接后需要归还到池中通过监听ctx事件实现
2、连接池允许连接数满后需要等待新的连接等待超时时间为10秒
3、缓存池内的连接可以共用连接超时或不可用时自动删除连接
4、group也并发测试一下另外对各个池做一下压力测试
5、你通过阅读代码尽可能多的覆盖测试
回答如下:
好的你的需求很明确。下面我会为你的连接池ChanPool和缓存池CachePool分别设计详细的单元测试覆盖你提到的所有核心功能点并尽量覆盖更多边界和异常场景。
测试会用 Go 的 testing 包实现,并会模拟一个简单的 Conn 实现用于测试。
测试点设计说明:
1. 连接池ChanPool测试点
不同 goroutine 获取到的连接不是同一个实例。
连接用完后归还池中,池可复用。
连接池满时新请求会等待,超时后报错。
连接池关闭后无法再获取连接。
连接池 Resize 后行为正常。
健康检查能自动剔除失效连接。
2. 缓存池CachePool测试点
获取到的连接可以复用。
连接超时后自动删除。
Put 超过最大连接数时,最久未用的被清理。
池关闭后无法再获取连接。
Resize 后行为正常。
PoolGroup 并发测试:
测试多个 goroutine 并发创建和获取连接池
验证连接池的正确创建和管理
基准测试:
使用 testing.B 进行并发基准测试
测试连接池在高并发下的性能表现
压力测试:
模拟真实场景下的高并发使用
统计成功率和失败率
测试连接池在极限情况下的表现
异常情况测试:
测试连接失效时的行为
测试并发关闭时的行为
验证连接池的健壮性
*/
// mockConn 实现 Conn 接口
type mockConn struct {
id int
closed bool
pingErr error
mu sync.Mutex
}
func (c *mockConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
func (c *mockConn) Ping() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return errors.New("closed")
}
return c.pingErr
}
func newMockConn(id int) *mockConn {
return &mockConn{id: id}
}
// ========== ChanPool 单元测试 ==========
func TestChanPool_Basic(t *testing.T) {
var idGen int
pool := NewChannelPool(func() (Conn, error) {
idGen++
return newMockConn(idGen), nil
}, WithMaxConns(2), WithIdleTimeout(time.Second))
ctx := context.Background()
conn1, _ := pool.Get(ctx)
conn2, _ := pool.Get(ctx)
if conn1 == conn2 {
t.Fatal("连接池应返回不同连接")
}
// 归还后可复用
_ = pool.Put(conn1)
conn3, _ := pool.Get(ctx)
if conn3 != conn1 {
t.Fatal("归还的连接应被复用")
}
_ = pool.Put(conn2)
_ = pool.Put(conn3)
pool.Close()
}
func TestChanPool_WaitTimeout(t *testing.T) {
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(1), WithWaitTimeout(100*time.Millisecond))
ctx := context.Background()
conn1, _ := pool.Get(ctx)
// 第二个请求会阻塞并超时
ctx2, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel()
start := time.Now()
_, err := pool.Get(ctx2)
if err == nil || time.Since(start) < 100*time.Millisecond {
t.Fatal("应因池满而超时")
}
_ = pool.Put(conn1)
pool.Close()
}
func TestChanPool_ContextCancel(t *testing.T) {
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(1))
ctx, cancel := context.WithCancel(context.Background())
conn, _ := pool.Get(ctx)
cancel()
time.Sleep(10 * time.Millisecond) // 等待归还
_ = pool.Put(conn)
pool.Close()
}
func TestChanPool_Resize(t *testing.T) {
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(2))
ctx := context.Background()
conn1, _ := pool.Get(ctx)
conn2, _ := pool.Get(ctx)
pool.Resize(1)
_ = pool.Put(conn1)
_ = pool.Put(conn2)
pool.Close()
}
func TestChanPool_HealthCheck(t *testing.T) {
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(1), WithIdleTimeout(10*time.Millisecond), WithHealthCheckInterval(10*time.Millisecond))
ctx := context.Background()
conn, _ := pool.Get(ctx)
_ = pool.Put(conn)
time.Sleep(30 * time.Millisecond)
stats := pool.Stats()
if stats.IdleConns != 0 {
t.Fatal("健康检查应清理超时连接")
}
pool.Close()
}
// ========== CachePool 单元测试 ==========
func TestCachePool_Basic(t *testing.T) {
var idGen int
pool := NewCachePool(func() (Conn, error) {
idGen++
return newMockConn(idGen), nil
}, WithMaxConns(2), WithIdleTimeout(time.Second))
ctx := context.Background()
conn1, _ := pool.Get(ctx)
conn2, _ := pool.Get(ctx)
if conn1 != conn2 {
t.Fatal("缓存池应复用同一连接")
}
_ = pool.Put(conn1)
_ = pool.Put(conn2)
pool.Close()
}
func TestCachePool_TimeoutCleanup(t *testing.T) {
pool := NewCachePool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(1), WithIdleTimeout(10*time.Millisecond), WithHealthCheckInterval(10*time.Millisecond))
ctx := context.Background()
conn, _ := pool.Get(ctx)
_ = pool.Put(conn)
time.Sleep(30 * time.Millisecond)
stats := pool.Stats()
if stats.TotalConns != 0 {
t.Fatal("超时连接应被清理")
}
pool.Close()
}
func TestCachePool_OverMaxConns(t *testing.T) {
var idGen int
pool := NewCachePool(func() (Conn, error) {
idGen++
return newMockConn(idGen), nil
}, WithMaxConns(1))
ctx := context.Background()
conn1, _ := pool.Get(ctx)
_ = pool.Put(conn1)
conn2, _ := pool.Get(ctx)
_ = pool.Put(conn2)
if conn1 != conn2 {
t.Fatal("缓存池应复用同一连接")
}
// 放入第二个不同连接,最老的会被清理
conn3 := newMockConn(999)
_ = pool.Put(conn3)
if pool.Stats().TotalConns != 1 {
t.Fatal("超出最大连接数应只保留一个")
}
pool.Close()
}
func TestCachePool_Resize(t *testing.T) {
pool := NewCachePool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(2))
ctx := context.Background()
conn1, _ := pool.Get(ctx)
_ = pool.Put(conn1)
pool.Resize(1)
if pool.Stats().TotalConns != 1 {
t.Fatal("Resize 后应只保留一个连接")
}
pool.Close()
}
// ========== PoolGroup 并发测试 ==========
func TestPoolGroup_Concurrent(t *testing.T) {
group := NewPoolGroup[Conn]()
var wg sync.WaitGroup
const goroutines = 10
const iterations = 100
// 并发创建和获取连接池
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("pool_%d", id)
pool, err := group.GetChanPool(key, func() (Conn, error) {
return newMockConn(id), nil
})
if err != nil {
t.Errorf("获取连接池失败: %v", err)
return
}
if pool == nil {
t.Error("连接池不应为nil")
return
}
}
}(i)
}
wg.Wait()
// 验证所有池都被正确创建
pools := group.AllPool()
if len(pools) != goroutines {
t.Errorf("期望 %d 个连接池,实际有 %d 个", goroutines, len(pools))
}
// 清理所有池
group.CloseAll()
}
// ========== 压力测试 ==========
func BenchmarkChanPool_Concurrent(b *testing.B) {
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(100))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := context.Background()
for pb.Next() {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
_ = pool.Put(conn)
}
})
pool.Close()
}
func BenchmarkCachePool_Concurrent(b *testing.B) {
pool := NewCachePool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(100))
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := context.Background()
for pb.Next() {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
_ = pool.Put(conn)
}
})
pool.Close()
}
// 模拟真实场景的压力测试
func TestChanPool_Stress(t *testing.T) {
const (
goroutines = 50
iterations = 1000
)
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(20), WithWaitTimeout(time.Second))
var wg sync.WaitGroup
var errCount int32
var successCount int32
// 添加一个 done channel 用于通知所有 goroutine 停止
done := make(chan struct{})
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iterations; j++ {
select {
case <-done:
return
default:
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
conn, err := pool.Get(ctx)
if err != nil {
atomic.AddInt32(&errCount, 1)
cancel()
continue
}
// 模拟使用连接
time.Sleep(time.Millisecond)
_ = pool.Put(conn)
atomic.AddInt32(&successCount, 1)
cancel()
}
}
}()
}
// 等待所有操作完成
wg.Wait()
close(done) // 通知所有 goroutine 停止
// 确保所有连接都被正确关闭
pool.Close()
t.Logf("总请求数: %d", goroutines*iterations)
t.Logf("成功请求数: %d", successCount)
t.Logf("失败请求数: %d", errCount)
t.Logf("成功率: %.2f%%", float64(successCount)/float64(goroutines*iterations)*100)
}
func TestCachePool_Stress(t *testing.T) {
const (
goroutines = 50
iterations = 1000
)
pool := NewCachePool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(20), WithIdleTimeout(time.Minute))
var wg sync.WaitGroup
var errCount int32
var successCount int32
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iterations; j++ {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
conn, err := pool.Get(ctx)
if err != nil {
atomic.AddInt32(&errCount, 1)
cancel()
continue
}
// 模拟使用连接
time.Sleep(time.Millisecond)
_ = pool.Put(conn)
atomic.AddInt32(&successCount, 1)
cancel()
}
}()
}
wg.Wait()
pool.Close()
t.Logf("总请求数: %d", goroutines*iterations)
t.Logf("成功请求数: %d", successCount)
t.Logf("失败请求数: %d", errCount)
t.Logf("成功率: %.2f%%", float64(successCount)/float64(goroutines*iterations)*100)
}
// 测试连接池在连接失效时的行为
func TestChanPool_InvalidConn(t *testing.T) {
pool := NewChannelPool(func() (Conn, error) {
conn := newMockConn(1)
conn.pingErr = errors.New("connection invalid")
return conn, nil
}, WithMaxConns(1), WithHealthCheckInterval(10*time.Millisecond))
ctx := context.Background()
conn, _ := pool.Get(ctx)
_ = pool.Put(conn)
// 等待健康检查
time.Sleep(20 * time.Millisecond)
// 获取新连接
newConn, err := pool.Get(ctx)
if err != nil {
t.Fatal("应该能获取到新连接")
}
if newConn == conn {
t.Fatal("应该获取到新的连接实例")
}
_ = pool.Put(newConn)
pool.Close()
}
// 测试连接池在并发关闭时的行为
func TestChanPool_ConcurrentClose(t *testing.T) {
pool := NewChannelPool(func() (Conn, error) {
return newMockConn(1), nil
}, WithMaxConns(10))
var wg sync.WaitGroup
const goroutines = 10
done := make(chan struct{}) // 用于通知所有 goroutine 停止
// 并发获取连接
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-done:
return
default:
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
conn, err := pool.Get(ctx)
if err != nil {
cancel()
continue
}
_ = pool.Put(conn)
cancel()
}
}
}()
}
// 等待一段时间让 goroutine 运行
time.Sleep(100 * time.Millisecond)
// 关闭连接池
pool.Close()
// 通知所有 goroutine 停止
close(done)
// 等待所有 goroutine 完成
wg.Wait()
// 验证连接池已关闭
ctx := context.Background()
_, err := pool.Get(ctx)
if err != ErrPoolClosed {
t.Errorf("期望错误 %v实际错误 %v", ErrPoolClosed, err)
}
}
func TestPoolGroup_ConcurrentAccess(t *testing.T) {
group := NewPoolGroup[Conn]()
var wg sync.WaitGroup
const goroutines = 10
const iterations = 100
// 并发创建和获取连接池
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("pool_%d", id)
pool, err := group.GetChanPool(key, func() (Conn, error) {
return newMockConn(id), nil
})
if err != nil {
t.Errorf("获取连接池失败: %v", err)
return
}
if pool == nil {
t.Error("连接池不应为nil")
return
}
// 模拟使用连接池
ctx := context.Background()
conn, err := pool.Get(ctx)
if err != nil {
t.Errorf("获取连接失败: %v", err)
continue
}
_ = pool.Put(conn)
}
}(i)
}
wg.Wait()
// 验证所有池都被正确创建
pools := group.AllPool()
if len(pools) != goroutines {
t.Errorf("期望 %d 个连接池,实际有 %d 个", goroutines, len(pools))
}
// 并发关闭所有池
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("pool_%d", id)
_ = group.Close(key)
}(i)
}
wg.Wait()
// 验证所有池都已关闭
pools = group.AllPool()
if len(pools) != 0 {
t.Errorf("所有池应已关闭,但还有 %d 个池", len(pools))
}
}
func TestPoolGroup_ConcurrentClose(t *testing.T) {
group := NewPoolGroup[Conn]()
const goroutines = 10
// 先创建一些池
for i := 0; i < goroutines; i++ {
key := fmt.Sprintf("pool_%d", i)
_, err := group.GetChanPool(key, func() (Conn, error) {
return newMockConn(i), nil
})
if err != nil {
t.Fatal(err)
}
}
// 并发关闭所有池
var wg sync.WaitGroup
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
group.CloseAll()
}()
}
wg.Wait()
// 验证所有池都已关闭
pools := group.AllPool()
if len(pools) != 0 {
t.Errorf("所有池应已关闭,但还有 %d 个池", len(pools))
}
}