Python functools模块高阶函数实用指南

Python functools模块高阶函数实用指南

functools.partial固定函数的部分参数:

from functools import partial

def power(base, exponent):
return base ** exponent

square = partial(power, exponent=2)
cube = partial(power, exponent=3)

print(square(5)) # 25
print(cube(5)) # 125

partial内部创建一个新函数对象,调用时合并存储的args和调用时传入的args:

print(square.func) #
print(square.args) # ()
print(square.keywords) # {'exponent': 2}

partial对象的属性暴露了原始函数和冻结的参数。

partialmethod用于绑定方法的参数:

from functools import partialmethod

class Cell:
def __init__(self):
self._alive = False

def set_state(self, state):
self._alive = state

set_alive = partialmethod(set_state, True)
set_dead = partialmethod(set_state, False)

c = Cell()
c.set_alive()
print(c._alive) # True

partialmethod在类体中使用,绑定的是实例方法。

lru_cache的底层实现使用字典缓存函数返回值:

from functools import lru_cache
import time

@lru_cache(maxsize=128)
def fib(n):
if n < 2:
return n
return fib(n-1) + fib(n-2)

start = time.perf_counter()
print(fib(100))
print(f"Time: {time.perf_counter() - start:.6f}s")
print(fib.cache_info()) # CacheInfo(hits=98, misses=101, maxsize=128, currsize=101)

缓存命中时直接返回,避免重复计算。fib.cache_info()返回命中次数、未命中次数、当前大小等统计信息。

lru_cache的适用条件:

@lru_cache(maxsize=None)
def expensive_db_query(user_id):
# 数据库查询
pass

函数参数必须是可哈希的。函数应该是纯函数(无副作用)。maxsize=None时缓存无限增长。

cached_property将方法调用结果缓存在实例上:

from functools import cached_property
import time

class DataProcessor:
def __init__(self, data):
self.data = data

@cached_property
def processed(self):
print("Processing data...")
time.sleep(1)
return sum(self.data)

dp = DataProcessor([1, 2, 3, 4, 5])
print(dp.processed) # 第一次调用,执行处理
print(dp.processed) # 直接返回缓存值

cached_property在首次访问时计算结果并存储在实例__dict__中。同名的实例属性覆盖描述符。

cached_property的线程安全:

class ThreadSafeProcessor:
def __init__(self, data):
self.data = data

@cached_property
def processed(self):
return sum(self.data) * 2

cached_property未使用锁。多个线程同时首次访问时可能重复计算。

total_ordering自动完善比较方法:

from functools import total_ordering

@total_ordering
class Person:
def __init__(self, name, age):
self.name = name
self.age = age

def __eq__(self, other):
return self.age == other.age

def __lt__(self, other):
return self.age < other.age

p1 = Person("Alice", 30)
p2 = Person("Bob", 25)
p3 = Person("Charlie", 30)

print(p1 > p2) # True
print(p1 >= p3) # True
print(p2 <= p1) # True

total_ordering根据__eq__和__lt__自动生成__le__、__gt__、__ge__。

singledispatch实现单分派泛型函数:

from functools import singledispatch

@singledispatch
def process(obj):
raise NotImplementedError("Unsupported type")

@process.register(int)
def _(obj):
return f"Integer: {obj}"

@process.register(str)
def _(obj):
return f"String: {obj}"

@process.register(list)
@process.register(tuple)
def _(obj):
return f"Sequence: {len(obj)} elements"

print(process(42)) # Integer: 42
print(process("hello")) # String: hello
print(process([1,2,3])) # Sequence: 3 elements

singledispatch根据第一个参数的类型分派到不同的实现。dispatch函数在方法注册表中查找类型。

singledispatchmethod用于类方法:

from functools import singledispatchmethod

class Formatter:
@singledispatchmethod
def format(self, arg):
raise NotImplementedError

@format.register(int)
def _(self, arg):
return f"int: {arg}"

@format.register(str)
def _(self, arg):
return f"str: {arg}"

f = Formatter()
print(f.format(42)) # int: 42
print(f.format("hello")) # str: hello

singledispatchmethod支持self参数,分派基于第二个参数(实际数据)。

wraps保持装饰函数的元信息:

from functools import wraps

def my_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper

@my_decorator
def example():
"""Docstring"""
pass

print(example.__name__) # example(没有wraps时是wrapper)
print(example.__doc__) # Docstring(没有wraps时是None)

wraps复制__module__、__name__、__qualname__、__doc__、__annotations__、__dict__和__wrapped__。

update_wrapper手动复制属性:

from functools import update_wrapper

def wrapper(func):
def inner(*args, **kwargs):
return func(*args, **kwargs)
update_wrapper(inner, func, assigned=['__name__', '__doc__'])
return inner

reduce从左到右累积序列元素:

from functools import reduce

result = reduce(lambda x, y: x * y, [1, 2, 3, 4, 5])
print(result) # 120

fac = lambda n: reduce(lambda a, b: a * b, range(1, n + 1), 1)

reduce的行为是:前两个元素应用函数,结果与第三个元素应用函数,以此类推。

cmp_to_key将旧式比较函数转为key函数:

from functools import cmp_to_key

def compare(a, b):
return -1 if a < b else 1 if a > b else 0

sorted_numbers = sorted([3, 1, 4, 1, 5], key=cmp_to_key(compare))

cache是lru_cache的无限制版本(Python 3.9+):

from functools import cache
import time

@cache
def fibonacci(n):
if n < 2:
return n
return fibonacci(n-1) + fibonacci(n-2)

start = time.perf_counter()
fibonacci(100)
print(f"Time: {time.perf_counter() - start:.6f}s")