Skip to content

泛型 Generics

Go 1.18 引入泛型,是 Go 语言最重要的特性更新。合理使用泛型能大幅减少重复代码。

基础语法

go
// 泛型函数:[T constraint] 是类型参数列表
func Map[T, U any](slice []T, f func(T) U) []U {
    result := make([]U, len(slice))
    for i, v := range slice {
        result[i] = f(v)
    }
    return result
}

func Filter[T any](slice []T, pred func(T) bool) []T {
    var result []T
    for _, v := range slice {
        if pred(v) {
            result = append(result, v)
        }
    }
    return result
}

func Reduce[T, U any](slice []T, init U, f func(U, T) U) U {
    result := init
    for _, v := range slice {
        result = f(result, v)
    }
    return result
}

// 使用
nums := []int{1, 2, 3, 4, 5}
doubled := Map(nums, func(n int) int { return n * 2 })
evens := Filter(nums, func(n int) bool { return n%2 == 0 })
sum := Reduce(nums, 0, func(acc, n int) int { return acc + n })

类型约束

go
// any = interface{},无约束
func Print[T any](v T) { fmt.Println(v) }

// comparable — 可以用 == 比较
func Contains[T comparable](slice []T, item T) bool {
    for _, v := range slice {
        if v == item {
            return true
        }
    }
    return false
}

// 内置约束:constraints 包(golang.org/x/exp/constraints)
// 或直接用联合类型
type Number interface {
    int | int8 | int16 | int32 | int64 |
    uint | uint8 | uint16 | uint32 | uint64 |
    float32 | float64
}

func Sum[T Number](nums []T) T {
    var total T
    for _, n := range nums {
        total += n
    }
    return total
}

fmt.Println(Sum([]int{1, 2, 3}))         // 6
fmt.Println(Sum([]float64{1.1, 2.2}))    // 3.3

~ 底层类型约束

go
// ~ 表示"底层类型为 T 的所有类型"
type Ordered interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64 |
    ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
    ~float32 | ~float64 | ~string
}

func Min[T Ordered](a, b T) T {
    if a < b {
        return a
    }
    return b
}

// 自定义类型也满足约束
type Celsius float64
type Fahrenheit float64

fmt.Println(Min(Celsius(100), Celsius(37)))  // 37

泛型数据结构

泛型栈

go
type Stack[T any] struct {
    items []T
}

func (s *Stack[T]) Push(item T) {
    s.items = append(s.items, item)
}

func (s *Stack[T]) Pop() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    n := len(s.items) - 1
    item := s.items[n]
    s.items = s.items[:n]
    return item, true
}

func (s *Stack[T]) Peek() (T, bool) {
    if len(s.items) == 0 {
        var zero T
        return zero, false
    }
    return s.items[len(s.items)-1], true
}

func (s *Stack[T]) Len() int { return len(s.items) }

// 使用
intStack := &Stack[int]{}
intStack.Push(1)
intStack.Push(2)
v, _ := intStack.Pop()  // 2

strStack := &Stack[string]{}
strStack.Push("hello")

泛型有序 Map

go
type OrderedMap[K Ordered, V any] struct {
    keys   []K
    values map[K]V
}

func NewOrderedMap[K Ordered, V any]() *OrderedMap[K, V] {
    return &OrderedMap[K, V]{values: make(map[K]V)}
}

func (m *OrderedMap[K, V]) Set(key K, value V) {
    if _, exists := m.values[key]; !exists {
        m.keys = append(m.keys, key)
        sort.Slice(m.keys, func(i, j int) bool {
            return m.keys[i] < m.keys[j]
        })
    }
    m.values[key] = value
}

func (m *OrderedMap[K, V]) Get(key K) (V, bool) {
    v, ok := m.values[key]
    return v, ok
}

func (m *OrderedMap[K, V]) Keys() []K {
    return m.keys
}

实用泛型工具函数

go
// Keys — 获取 map 的所有键
func Keys[K comparable, V any](m map[K]V) []K {
    keys := make([]K, 0, len(m))
    for k := range m {
        keys = append(keys, k)
    }
    return keys
}

// Values — 获取 map 的所有值
func Values[K comparable, V any](m map[K]V) []V {
    vals := make([]V, 0, len(m))
    for _, v := range m {
        vals = append(vals, v)
    }
    return vals
}

// Must — 包装返回 (T, error) 的函数,panic on error
func Must[T any](v T, err error) T {
    if err != nil {
        panic(err)
    }
    return v
}

// Ptr — 获取值的指针(常用于测试)
func Ptr[T any](v T) *T { return &v }

// Coalesce — 返回第一个非零值
func Coalesce[T comparable](values ...T) T {
    var zero T
    for _, v := range values {
        if v != zero {
            return v
        }
    }
    return zero
}

// 使用
name := Coalesce(user.Nickname, user.Username, "匿名用户")

泛型约束组合

go
// 接口可以组合多个约束
type Stringer interface {
    String() string
}

type PrintableNumber interface {
    Number
    Stringer
}

// 方法约束
type Container[T any] interface {
    Add(T)
    Remove() (T, bool)
    Len() int
}

// 泛型函数接受泛型接口
func Drain[T any](c Container[T]) []T {
    var items []T
    for c.Len() > 0 {
        if item, ok := c.Remove(); ok {
            items = append(items, item)
        }
    }
    return items
}

泛型的限制

go
// 1. 不支持泛型方法(只支持泛型类型上的方法)
// ❌ 错误
func (s *Stack[T]) Map[U any](f func(T) U) []U { ... }

// ✅ 正确:用包级函数
func StackMap[T, U any](s *Stack[T], f func(T) U) []U { ... }

// 2. 类型参数不能用于 switch
func process[T any](v T) {
    // ❌ 不支持
    // switch v.(type) { ... }

    // ✅ 用 reflect 或 any 转换
    switch any(v).(type) {
    case int:
        fmt.Println("int")
    case string:
        fmt.Println("string")
    }
}

// 3. 不支持协变/逆变
// []int 不能赋值给 []any

何时使用泛型

  • 编写通用数据结构(栈、队列、集合)
  • 编写通用算法(Map/Filter/Reduce)
  • 消除重复的类型断言代码

不要为了泛型而泛型。如果 interface{} 或代码生成更简单,就用那个。

本站内容由 褚成志 整理编写,仅供学习参考