Context 传播机制
context.Context是 Go 并发编程中最重要的接口之一,用于传递截止时间、取消信号和请求范围的值。
Context 接口
go
type Context interface {
// 截止时间,如果没有设置则 ok=false
Deadline() (deadline time.Time, ok bool)
// 返回一个 channel,当 context 被取消时关闭
Done() <-chan struct{}
// Done channel 关闭的原因
// context.Canceled 或 context.DeadlineExceeded
Err() error
// 获取 key 对应的值
Value(key any) any
}四种 Context
go
// 1. Background — 根 context,永不取消
ctx := context.Background()
// 2. TODO — 占位符,表示还不确定用哪个 context
ctx := context.TODO()
// 3. WithCancel — 手动取消
ctx, cancel := context.WithCancel(parent)
defer cancel() // 重要:必须调用,防止 goroutine 泄漏
// 4. WithTimeout — 超时自动取消
ctx, cancel := context.WithTimeout(parent, 5*time.Second)
defer cancel()
// 5. WithDeadline — 指定截止时间
deadline := time.Now().Add(5 * time.Second)
ctx, cancel := context.WithDeadline(parent, deadline)
defer cancel()
// 6. WithValue — 携带值(谨慎使用)
ctx = context.WithValue(parent, userIDKey, 42)取消传播
go
// context 形成树形结构,父 context 取消会传播到所有子 context
func main() {
root := context.Background()
// 创建父 context
parent, parentCancel := context.WithCancel(root)
// 创建子 context
child1, _ := context.WithTimeout(parent, 10*time.Second)
child2, _ := context.WithCancel(parent)
// 取消父 context
parentCancel()
// child1 和 child2 也会被取消
<-child1.Done() // 立即返回
<-child2.Done() // 立即返回
fmt.Println(child1.Err()) // context canceled
}实战:HTTP 请求超时
go
func fetchUser(ctx context.Context, userID int) (*User, error) {
// 创建带超时的子 context
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET",
fmt.Sprintf("https://api.example.com/users/%d", userID), nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return nil, fmt.Errorf("请求超时")
}
return nil, err
}
defer resp.Body.Close()
var user User
return &user, json.NewDecoder(resp.Body).Decode(&user)
}实战:goroutine 生命周期管理
go
func startWorker(ctx context.Context) {
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
fmt.Println("worker 退出:", ctx.Err())
return
case t := <-ticker.C:
fmt.Println("工作中:", t.Format("15:04:05"))
}
}
}()
}
func main() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
startWorker(ctx)
<-ctx.Done()
fmt.Println("主程序退出")
time.Sleep(100 * time.Millisecond) // 等待 worker 清理
}实战:数据库操作
go
func (r *UserRepo) FindByID(ctx context.Context, id int) (*User, error) {
query := "SELECT id, name, email FROM users WHERE id = $1"
var user User
err := r.db.QueryRowContext(ctx, query, id).Scan(
&user.ID, &user.Name, &user.Email,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("FindByID: %w", err)
}
return &user, nil
}
// 事务中使用 context
func (r *UserRepo) Transfer(ctx context.Context, fromID, toID int, amount float64) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback() // 如果 Commit 成功,Rollback 是 no-op
// 所有操作都传入 ctx,支持超时取消
if _, err = tx.ExecContext(ctx,
"UPDATE accounts SET balance = balance - $1 WHERE id = $2",
amount, fromID); err != nil {
return err
}
if _, err = tx.ExecContext(ctx,
"UPDATE accounts SET balance = balance + $1 WHERE id = $2",
amount, toID); err != nil {
return err
}
return tx.Commit()
}WithValue 的正确使用
go
// 定义私有 key 类型,避免冲突
type contextKey string
const (
userIDKey contextKey = "userID"
requestIDKey contextKey = "requestID"
traceIDKey contextKey = "traceID"
)
// 封装 getter/setter(推荐)
func WithUserID(ctx context.Context, userID int) context.Context {
return context.WithValue(ctx, userIDKey, userID)
}
func GetUserID(ctx context.Context) (int, bool) {
id, ok := ctx.Value(userIDKey).(int)
return id, ok
}
// 在中间件中注入
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID := validateToken(r.Header.Get("Authorization"))
ctx := WithUserID(r.Context(), userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// 在处理函数中使用
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
userID, ok := GetUserID(r.Context())
if !ok {
http.Error(w, "未认证", 401)
return
}
// ...
}Context 使用规范
go
// ✅ 正确:context 作为第一个参数
func DoSomething(ctx context.Context, arg string) error { ... }
// ❌ 错误:context 放在结构体中
type Service struct {
ctx context.Context // 不要这样做
}
// ✅ 正确:不确定时用 context.TODO()
func legacyFunc() {
ctx := context.TODO() // 标记为待改进
doWork(ctx)
}
// ✅ 正确:始终 defer cancel()
ctx, cancel := context.WithTimeout(parent, time.Second)
defer cancel() // 即使超时触发,也要调用以释放资源
// ❌ 错误:忽略 cancel
ctx, _ = context.WithTimeout(parent, time.Second) // 资源泄漏!Context 传递的是什么
Context 适合传递:请求 ID、用户认证信息、链路追踪 ID、截止时间。 不适合传递:函数的可选参数、数据库连接、配置信息(用依赖注入)。