聊聊gost的GenericTaskPool

/ golang / 没有评论 / 20浏览

本文主要研究一下gost的GenericTaskPool

GenericTaskPool

gost/sync/task_pool.go

// GenericTaskPool represents a generic task pool.
type GenericTaskPool interface {
	// AddTask waits for an idle worker and adds the task; the boolean result
	// reports whether the task was accepted.
	AddTask(t task) bool
	// AddTaskAlways adds the task to a queue or runs it immediately.
	AddTaskAlways(t task)
	// AddTaskBalance adds the task to an idle queue.
	AddTaskBalance(t task)
	// Close closes the task pool.
	Close()
	// IsClosed reports whether the pool has been closed.
	IsClosed() bool
}

GenericTaskPool接口定义了AddTask、AddTaskAlways、AddTaskBalance、Close、IsClosed这五个方法

TaskPool

gost/sync/task_pool.go

// TaskPool is a task pool backed by several task queues. It embeds the
// TaskPoolOptions it was created with.
type TaskPool struct {
	TaskPoolOptions

	idx    uint32 // round robin index
	// qArray holds one channel per task queue.
	qArray []chan task
	// wg counts the worker goroutines; Close waits on it.
	wg     sync.WaitGroup

	// once is not used by the methods visible here.
	// NOTE(review): presumably guards stop() - confirm.
	once sync.Once
	// done is the stop signal: workers, AddTask and IsClosed treat a
	// successful receive from it as "the pool is closed".
	done chan struct{}
}

// AddTask enqueues t on the next queue in round-robin order, blocking while
// that queue is full.
//
// It returns true once t has been queued. It returns false when the pool is
// already stopped, or when the pool is stopped while the call is still
// waiting for queue space.
func (p *TaskPool) AddTask(t task) (ok bool) {
	idx := atomic.AddUint32(&p.idx, 1)
	id := idx % uint32(p.tQNumber)

	// A stopped pool must always reject the task, so done is checked on its
	// own first: a single select picks at random when several cases are ready.
	select {
	case <-p.done:
		return false
	default:
	}

	// Wait for queue space, but give up as soon as the pool stops. Workers
	// return once done is readable, so nothing drains a full queue afterwards,
	// and Close later closes the queue channel, which would panic a sender
	// that is still blocked on it.
	select {
	case <-p.done:
		return false
	case p.qArray[id] <- t:
		return true
	}
}

// AddTaskAlways offers t to the next queue in round-robin order without
// blocking. When that queue cannot accept it right away, t is handed to
// goSafely instead. The done channel is not consulted.
func (p *TaskPool) AddTaskAlways(t task) {
	next := atomic.AddUint32(&p.idx, 1)
	slot := next % uint32(p.tQNumber)

	select {
	case p.qArray[slot] <- t:
		// queued
	default:
		goSafely(t)
	}
}

// AddTaskBalance tries, without blocking, to place t on a randomly chosen
// queue and hands it to goSafely when no attempt succeeds.
//
// It makes len(qArray)/2 attempts, but never fewer than one while at least
// one queue exists. Without that floor a pool with a single queue would never
// use its queue and would pass every task to goSafely.
func (p *TaskPool) AddTaskBalance(t task) {
	length := len(p.qArray)

	attempts := length / 2
	if attempts == 0 && length > 0 {
		attempts = 1
	}

	// look for a queue with free space
	for i := 0; i < attempts; i++ {
		select {
		case p.qArray[rand.Intn(length)] <- t:
			return
		default:
			// the chosen queue is full; try another random one
		}
	}

	goSafely(t)
}

// IsClosed reports whether the pool has been closed, i.e. whether a receive
// from the done channel succeeds without blocking.
func (p *TaskPool) IsClosed() bool {
	closed := false

	select {
	case <-p.done:
		closed = true
	default:
	}

	return closed
}

// Close calls stop, waits on the pool's WaitGroup and then closes every
// queue channel.
func (p *TaskPool) Close() {
	p.stop()
	p.wg.Wait()

	for _, q := range p.qArray {
		close(q)
	}
}

TaskPool定义了TaskPoolOptions、idx、qArray、wg、once、done属性;它实现了GenericTaskPool接口;AddTask方法先递增idx并根据tQNumber计算id,若pool已经done则返回false,否则往qArray[id]写入task(队列满时会阻塞)并返回true;AddTaskAlways方法不检查pool的关闭状态,它以非阻塞的方式往qArray[id]写入task,写入不成功则通过goSafely执行;AddTaskBalance方法会尝试len/2次随机往qArray写入task,都写入不成功则goSafely执行;IsClosed主要是读取done的channel信息;Close方法执行stop及wg.Wait(),最后遍历qArray挨个执行close

NewTaskPool

gost/sync/task_pool.go

// NewTaskPool applies opts to a zero TaskPoolOptions, validates the result,
// allocates tQNumber queues with a buffer of tQLen tasks each, and starts the
// pool before returning it.
func NewTaskPool(opts ...TaskPoolOption) GenericTaskPool {
	options := TaskPoolOptions{}
	for i := range opts {
		opts[i](&options)
	}
	options.validate()

	queues := make([]chan task, options.tQNumber)
	for i := range queues {
		queues[i] = make(chan task, options.tQLen)
	}

	pool := &TaskPool{
		TaskPoolOptions: options,
		qArray:          queues,
		done:            make(chan struct{}),
	}
	pool.start()

	return pool
}

NewTaskPool通过TaskPoolOptions来创建TaskPool

TaskPoolOption

gost/sync/options.go

// Fallback values that validate applies when the corresponding option is
// smaller than 1.
const (
	// defaultTaskQNumber is the default number of task queues.
	defaultTaskQNumber = 10
	// defaultTaskQLen is the default buffer size of each task queue.
	defaultTaskQLen    = 128
)

/////////////////////////////////////////
// Task Pool Options
/////////////////////////////////////////

// TaskPoolOptions holds the optional settings of a task pool. The fields are
// set through TaskPoolOption functions and normalized by validate.
type TaskPoolOptions struct {
	tQLen      int // task queue length. buffer size per queue
	tQNumber   int // task queue number. number of queue
	tQPoolSize int // task pool size. number of workers
}

// validate normalizes the options in place.
//
// It panics when the pool size is smaller than 1. A queue length or queue
// number smaller than 1 is replaced by its default, and the queue number is
// capped at the pool size.
func (o *TaskPoolOptions) validate() {
	if o.tQPoolSize <= 0 {
		panic(fmt.Sprintf("illegal pool size %d", o.tQPoolSize))
	}

	if o.tQLen <= 0 {
		o.tQLen = defaultTaskQLen
	}

	switch {
	case o.tQNumber <= 0:
		o.tQNumber = defaultTaskQNumber
	}

	// never use more queues than there are workers
	if o.tQPoolSize < o.tQNumber {
		o.tQNumber = o.tQPoolSize
	}
}

// TaskPoolOption is a function that modifies a TaskPoolOptions value.
type TaskPoolOption func(*TaskPoolOptions)

// WithTaskPoolTaskPoolSize returns an option that sets the task pool size
// (the number of workers) to size.
func WithTaskPoolTaskPoolSize(size int) TaskPoolOption {
	apply := func(opts *TaskPoolOptions) {
		opts.tQPoolSize = size
	}
	return apply
}

// WithTaskPoolTaskQueueLength returns an option that sets the task queue
// length (the buffer size of each queue) to length.
func WithTaskPoolTaskQueueLength(length int) TaskPoolOption {
	apply := func(opts *TaskPoolOptions) {
		opts.tQLen = length
	}
	return apply
}

// WithTaskPoolTaskQueueNumber returns an option that sets the number of task
// queues to number.
func WithTaskPoolTaskQueueNumber(number int) TaskPoolOption {
	apply := func(opts *TaskPoolOptions) {
		opts.tQNumber = number
	}
	return apply
}

TaskPoolOptions定义了tQLen、tQNumber、tQPoolSize属性,提供了WithTaskPoolTaskPoolSize、WithTaskPoolTaskQueueLength、WithTaskPoolTaskQueueNumber、validate方法

start

gost/sync/task_pool.go

// start launches tQPoolSize workers. Worker i serves queue i % tQNumber, and
// each worker is registered on the WaitGroup before it is launched.
func (p *TaskPool) start() {
	for worker := 0; worker < p.tQPoolSize; worker++ {
		p.wg.Add(1)
		p.safeRun(worker, p.qArray[worker%p.tQNumber])
	}
}

// safeRun hands the worker loop for queue q to gxruntime.GoSafely and logs
// any error the loop returns through the standard logger.
func (p *TaskPool) safeRun(workerID int, q chan task) {
	worker := func() {
		if err := p.run(workerID, q); err != nil {
			// report the worker's exit error
			log.Printf("gost/TaskPool.run error: %s", err.Error())
		}
	}

	gxruntime.GoSafely(nil, false, worker, nil)
}

start方法根据tQPoolSize挨个执行safeRun;safeRun方法通过GoSafely执行p.run(int(workerID), q)

run

gost/sync/task_pool.go

// run is the worker loop for queue q.
//
// It executes tasks received from q until the done channel becomes readable.
// On exit it returns an error when q still holds buffered tasks and nil
// otherwise. A panic raised by a task is recovered and written to stderr with
// a stack trace, so the worker keeps running. The WaitGroup is released when
// run returns.
func (p *TaskPool) run(id int, q chan task) error {
	defer p.wg.Done()

	// execute runs one task and contains any panic it raises.
	execute := func(job task) {
		defer func() {
			if r := recover(); r != nil {
				fmt.Fprintf(os.Stderr, "%s goroutine panic: %v\n%s\n",
					time.Now(), r, string(debug.Stack()))
			}
		}()
		job()
	}

	for {
		select {
		case <-p.done:
			if len(q) == 0 {
				return nil
			}
			return fmt.Errorf("task worker %d exit now while its task buffer length %d is greater than 0",
				id, len(q))

		case job, open := <-q:
			if !open {
				continue
			}
			execute(job)
		}
	}
}

run方法通过for循环进行select,若读取到p.done则退出循环(此时若队列中仍有未执行的task则返回error,否则返回nil);若是读取到task则执行task,并通过recover捕获task执行过程中的panic

taskPoolSimple

gost/sync/task_pool.go

// taskPoolSimple is a task pool built on a single task channel.
// NOTE(review): its methods are not visible here; the field roles below come
// from the original field comments and should be confirmed against them.
type taskPoolSimple struct {
	work chan task     // task channel
	sem  chan struct{} // gr pool size

	wg sync.WaitGroup

	once sync.Once
	done chan struct{}
}

taskPoolSimple定义了work、sem、wg、once、done属性;它实现了GenericTaskPool接口;AddTask方法先判断done,之后写入work及sem;AddTaskAlways方法在select不到channel的时候会执行goSafely;AddTaskBalance方法实际执行的是AddTaskAlways;IsClosed方法读取done信息;Close方法执行stop及wg.Wait()

实例

gost/sync/task_pool_test.go

// TestTaskPool submits tasks to a pool with one worker and one single-slot
// queue from NumCPU*NumCPU concurrent producers (100 tasks each), then closes
// the pool. Rejected submissions are only logged; the executed-task count is
// not asserted (that check is commented out below).
func TestTaskPool(t *testing.T) {
	cpus := runtime.NumCPU()
	producers := cpus * cpus
	//taskCnt := int64(numCPU * numCPU * 100)

	tp := NewTaskPool(
		WithTaskPoolTaskPoolSize(1),
		WithTaskPoolTaskQueueNumber(1),
		WithTaskPoolTaskQueueLength(1),
	)

	//task, cnt := newCountTask()
	job, _ := newCountTask()

	var wg sync.WaitGroup
	for producer := 0; producer < producers; producer++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 100; j++ {
				if accepted := tp.AddTask(job); !accepted {
					t.Log(j)
				}
			}
		}()
	}
	wg.Wait()
	tp.Close()

	//if taskCnt != atomic.LoadInt64(cnt) {
	//	//t.Error("want ", taskCnt, " got ", *cnt)
	//}
}

// TestTaskPoolSimple submits NumCPU*NumCPU*100 counting tasks to a simple
// pool of size 1 from concurrent producers and checks that the counter equals
// the number of submitted tasks.
func TestTaskPoolSimple(t *testing.T) {
	cpus := runtime.NumCPU()
	producers := cpus * cpus
	taskCnt := int64(producers * 100)

	tp := NewTaskPoolSimple(1)

	job, cnt := newCountTask()

	var wg sync.WaitGroup
	for producer := 0; producer < producers; producer++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 100; j++ {
				if accepted := tp.AddTask(job); !accepted {
					t.Log(j)
				}
			}
		}()
	}
	wg.Wait()

	if cntValue := atomic.LoadInt64(cnt); taskCnt != cntValue {
		t.Error("want ", taskCnt, " got ", cntValue)
	}
}

TaskPoolSimple的创建比较简单,只需要提供size参数即可;TaskPool的创建需要提供TaskPoolOption,有WithTaskPoolTaskPoolSize、WithTaskPoolTaskQueueNumber、WithTaskPoolTaskQueueLength这些option

小结

gost的GenericTaskPool接口定义了AddTask、AddTaskAlways、AddTaskBalance、Close、IsClosed这五个方法;这里有TaskPool、taskPoolSimple两个实现。

doc