Skip to content

Context 传播机制

context.Context 是 Go 并发编程中最重要的接口之一,用于传递截止时间、取消信号和请求范围的值。

Context 接口

go
type Context interface {
    // 截止时间,如果没有设置则 ok=false
    Deadline() (deadline time.Time, ok bool)

    // 返回一个 channel,当 context 被取消时关闭
    Done() <-chan struct{}

    // Done channel 关闭的原因
    // context.Canceled 或 context.DeadlineExceeded
    Err() error

    // 获取 key 对应的值
    Value(key any) any
}

四种 Context

go
// 1. Background — 根 context,永不取消
ctx := context.Background()

// 2. TODO — 占位符,表示还不确定用哪个 context
ctx := context.TODO()

// 3. WithCancel — 手动取消
ctx, cancel := context.WithCancel(parent)
defer cancel()  // 重要:必须调用,防止 goroutine 泄漏

// 4. WithTimeout — 超时自动取消
ctx, cancel := context.WithTimeout(parent, 5*time.Second)
defer cancel()

// 5. WithDeadline — 指定截止时间
deadline := time.Now().Add(5 * time.Second)
ctx, cancel := context.WithDeadline(parent, deadline)
defer cancel()

// 6. WithValue — 携带值(谨慎使用)
ctx = context.WithValue(parent, userIDKey, 42)

取消传播

go
// context 形成树形结构,父 context 取消会传播到所有子 context
func main() {
    root := context.Background()

    // 创建父 context
    parent, parentCancel := context.WithCancel(root)

    // 创建子 context
    child1, _ := context.WithTimeout(parent, 10*time.Second)
    child2, _ := context.WithCancel(parent)

    // 取消父 context
    parentCancel()

    // child1 和 child2 也会被取消
    <-child1.Done()  // 立即返回
    <-child2.Done()  // 立即返回
    fmt.Println(child1.Err())  // context canceled
}

实战:HTTP 请求超时

go
func fetchUser(ctx context.Context, userID int) (*User, error) {
    // 创建带超时的子 context
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    req, err := http.NewRequestWithContext(ctx, "GET",
        fmt.Sprintf("https://api.example.com/users/%d", userID), nil)
    if err != nil {
        return nil, err
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        if errors.Is(err, context.DeadlineExceeded) {
            return nil, fmt.Errorf("请求超时")
        }
        return nil, err
    }
    defer resp.Body.Close()

    var user User
    return &user, json.NewDecoder(resp.Body).Decode(&user)
}

实战:goroutine 生命周期管理

go
func startWorker(ctx context.Context) {
    go func() {
        ticker := time.NewTicker(time.Second)
        defer ticker.Stop()

        for {
            select {
            case <-ctx.Done():
                fmt.Println("worker 退出:", ctx.Err())
                return
            case t := <-ticker.C:
                fmt.Println("工作中:", t.Format("15:04:05"))
            }
        }
    }()
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    startWorker(ctx)

    <-ctx.Done()
    fmt.Println("主程序退出")
    time.Sleep(100 * time.Millisecond)  // 等待 worker 清理
}

实战:数据库操作

go
func (r *UserRepo) FindByID(ctx context.Context, id int) (*User, error) {
    query := "SELECT id, name, email FROM users WHERE id = $1"

    var user User
    err := r.db.QueryRowContext(ctx, query, id).Scan(
        &user.ID, &user.Name, &user.Email,
    )
    if err != nil {
        if errors.Is(err, sql.ErrNoRows) {
            return nil, ErrUserNotFound
        }
        return nil, fmt.Errorf("FindByID: %w", err)
    }
    return &user, nil
}

// 事务中使用 context
func (r *UserRepo) Transfer(ctx context.Context, fromID, toID int, amount float64) error {
    tx, err := r.db.BeginTx(ctx, nil)
    if err != nil {
        return err
    }
    defer tx.Rollback()  // 如果 Commit 成功,Rollback 是 no-op

    // 所有操作都传入 ctx,支持超时取消
    if _, err = tx.ExecContext(ctx,
        "UPDATE accounts SET balance = balance - $1 WHERE id = $2",
        amount, fromID); err != nil {
        return err
    }

    if _, err = tx.ExecContext(ctx,
        "UPDATE accounts SET balance = balance + $1 WHERE id = $2",
        amount, toID); err != nil {
        return err
    }

    return tx.Commit()
}

WithValue 的正确使用

go
// 定义私有 key 类型,避免冲突
type contextKey string

const (
    userIDKey    contextKey = "userID"
    requestIDKey contextKey = "requestID"
    traceIDKey   contextKey = "traceID"
)

// 封装 getter/setter(推荐)
func WithUserID(ctx context.Context, userID int) context.Context {
    return context.WithValue(ctx, userIDKey, userID)
}

func GetUserID(ctx context.Context) (int, bool) {
    id, ok := ctx.Value(userIDKey).(int)
    return id, ok
}

// 在中间件中注入
func authMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        userID := validateToken(r.Header.Get("Authorization"))
        ctx := WithUserID(r.Context(), userID)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

// 在处理函数中使用
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
    userID, ok := GetUserID(r.Context())
    if !ok {
        http.Error(w, "未认证", 401)
        return
    }
    // ...
}

Context 使用规范

go
// ✅ 正确:context 作为第一个参数
func DoSomething(ctx context.Context, arg string) error { ... }

// ❌ 错误:context 放在结构体中
type Service struct {
    ctx context.Context  // 不要这样做
}

// ✅ 正确:不确定时用 context.TODO()
func legacyFunc() {
    ctx := context.TODO()  // 标记为待改进
    doWork(ctx)
}

// ✅ 正确:始终 defer cancel()
ctx, cancel := context.WithTimeout(parent, time.Second)
defer cancel()  // 即使超时触发,也要调用以释放资源

// ❌ 错误:忽略 cancel
ctx, _ = context.WithTimeout(parent, time.Second)  // 资源泄漏!

Context 传递的是什么

Context 适合传递:请求 ID、用户认证信息、链路追踪 ID、截止时间。 不适合传递:函数的可选参数、数据库连接、配置信息(用依赖注入)。

本站内容由 褚成志 整理编写,仅供学习参考