Skip to content

神奇的装饰器函数

核心作用:优雅的“包装”

装饰器(Decorator)是 Python 的一个核心特性,其本质是一个接受函数作为参数并返回一个新函数的可调用对象(通常是函数)。 它的核心作用是:在不修改被装饰函数(或类)原始代码的情况下,为其动态地添加额外的功能。 这完美遵循了软件开发中的 “开放-封闭”原则:

  • 对扩展开放:可以轻松地为现有功能添加新行为。
  • 对修改封闭:无需改动原有的、已经工作正常的代码。 你可以把它想象成一个礼物包装。礼物本身(核心功能)没有变,但包装纸(装饰器)赋予了它更漂亮的外观或更多的含义(附加功能)。

常用场景

装饰器的应用极其广泛,以下是一些最常见和实用的场景:

日志记录 (Logging)

  • 作用:自动记录函数的执行信息,如函数名、传入的参数、返回值、执行时间等。这对于调试和监控程序运行状态非常有用。
  • 场景:你有一个处理用户请求的函数,你想知道每次是谁、在什么时候、调用了这个函数以及处理结果是什么,但又不想在每个函数开头和结尾都写print或logging.info。
python

import functools
import logging

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_function_call(func):
    """记录函数调用信息的装饰器"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.info(f"调用函数: {func.__name__}, 参数: args={args}, kwargs={kwargs}")
        result = func(*args, **kwargs)
        logger.info(f"函数 {func.__name__} 执行完成, 返回值: {result}")
        return result
    return wrapper

# 使用示例
@log_function_call
def add_numbers(a, b):
    """将两个数字相加"""
    return a + b

@log_function_call
def greet(name, greeting="Hello"):
    """打招呼函数"""
    return f"{greeting}, {name}!"

# 测试
if __name__ == "__main__":
    print(add_numbers(5, 3))
    print(greet("Alice", greeting="Hi"))

性能分析 (Performance Benchmarking / Timing)

  • 作用:测量并输出函数的执行时间,帮助你找到代码的性能瓶颈。
  • 场景:你想优化一段代码,但不确定到底是哪个函数运行缓慢。你可以用一个 @timer 装饰器来快速装饰几个可疑的函数,从而精确地测量它们的运行耗时。
python
import functools
import time

def timer(func):
    """测量函数执行时间的装饰器"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        execution_time = end_time - start_time
        print(f"函数 {func.__name__} 执行耗时: {execution_time:.4f} 秒")
        return result
    return wrapper

# 使用示例
@timer
def slow_function():
    """模拟一个耗时操作"""
    time.sleep(2)
    return "任务完成"

@timer
def fast_function():
    """快速执行的函数"""
    return "快速完成"

# 测试
if __name__ == "__main__":
    print(slow_function())
    print(fast_function())

计时 (Timing)

  • 作用:精确测量函数或代码块的执行时间,提供简洁直观的性能分析信息。
  • 场景:在优化算法性能时,你需要比较不同实现方案的耗时差异。或者在开发过程中快速了解函数的执行效率,帮助识别性能瓶颈。
python
import functools
import time
from contextlib import contextmanager

def timing(name=None):
    """精确计时装饰器"""
    def decorator(func):
        func_name = name or func.__name__
        
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            end_time = time.perf_counter()
            duration = end_time - start_time
            print(f"⏱️ {func_name} 执行耗时: {duration:.4f}s")
            return result
        
        return wrapper
    return decorator

@contextmanager
def time_block(name):
    """代码块计时上下文管理器"""
    start_time = time.perf_counter()
    try:
        yield
    finally:
        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"⏱️ {name} 代码块耗时: {duration:.4f}s")

# 使用示例
@timing("斐波那契计算")
def fibonacci_recursive(n):
    """递归计算斐波那契数列"""
    if n <= 1:
        return n
    return fibonacci_recursive(n-1) + fibonacci_recursive(n-2)

@timing("快速计算")
def fibonacci_iterative(n):
    """迭代计算斐波那契数列"""
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b

@timing()
def process_data():
    """模拟数据处理"""
    time.sleep(0.1)  # 模拟处理延时
    return "处理完成"

# 测试
if __name__ == "__main__":
    print("=== 计时装饰器测试 ===")
    
    # 比较不同算法性能
    print("\n算法性能比较:")
    result1 = fibonacci_recursive(8)  # 较小的数值避免过长时间
    result2 = fibonacci_iterative(30)
    print(f"结果: {result1}, {result2}")
    
    # 函数执行计时
    print("\n函数执行计时:")
    process_data()
    
    # 代码块计时
    print("\n代码块计时:")
    with time_block("批量计算"):
        total = 0
        for i in range(1000000):
            total += i
        print(f"计算结果: {total}")

权限校验与认证 (Authentication & Authorization)

  • 作用:在函数执行前,检查用户是否已登录或是否拥有执行该操作的权限。如果校验失败,则阻止函数执行并跳转到登录页或返回错误信息。
  • 场景:在 Web 开发(如 Django, Flask 框架)中极其常见。例如,@login_required 装饰器可以确保只有登录用户才能访问某个视图函数(页面)。
python
import functools

# 模拟用户数据库
users = {
    "admin": {"password": "admin123", "role": "admin"},
    "user1": {"password": "pass123", "role": "user"},
    "guest": {"password": "guest", "role": "guest"}
}

# 当前登录用户(模拟会话)
current_user = None

def login_required(required_role="user"):
    """要求用户登录的装饰器"""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if current_user is None:
                raise PermissionError("请先登录!")
            
            user_role = users.get(current_user, {}).get("role")
            if user_role is None:
                raise PermissionError("用户不存在或角色未定义")
            
            # 修正权限检查逻辑 - 使用角色层级系统
            role_hierarchy = {"guest": 0, "user": 1, "admin": 2}
            required_level = role_hierarchy.get(required_role, 1)
            user_level = role_hierarchy.get(user_role, 0)
            
            if user_level < required_level:
                raise PermissionError(f"需要 {required_role} 权限,当前用户权限: {user_role}")
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

# 使用示例
@login_required()
def view_profile():
    """查看用户资料"""
    return f"欢迎查看您的资料, {current_user}!"

@login_required("admin")
def admin_panel():
    """管理员面板"""
    return "欢迎来到管理员面板!"

# 测试
if __name__ == "__main__":
    # 模拟登录
    current_user = "user1"
    print(view_profile())
    
    try:
        print(admin_panel())
    except PermissionError as e:
        print(f"权限错误: {e}")
    
    # 测试未登录情况
    current_user = None
    try:
        print(view_profile())
    except PermissionError as e:
        print(f"权限错误: {e}")

输入验证与预处理 (Input Validation & Sanitization)

  • 作用:在函数执行前,检查传入的参数是否合法(如类型是否正确、数值是否在有效范围内),并对参数进行清理或转换。
  • 场景:一个计算圆面积的函数需要确保半径是一个非负数。你可以用一个装饰器来验证参数,如果半径是负数则自动抛出异常或返回错误,让核心函数专注于纯计算逻辑。
python
import functools

def validate_input(**validations):
    """验证函数参数的装饰器"""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # 将args和kwargs合并为字典以便验证
            all_args = dict(zip(func.__code__.co_varnames, args))
            all_args.update(kwargs)
            
            # 验证每个参数
            for param_name, validation in validations.items():
                if param_name in all_args:
                    value = all_args[param_name]
                    if not validation(value):
                        raise ValueError(f"参数 {param_name} 的值 {value} 不符合验证规则")
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

# 验证函数
def is_positive(number):
    return number > 0

def is_non_empty_string(text):
    return isinstance(text, str) and len(text.strip()) > 0

def is_valid_age(age):
    return isinstance(age, int) and 0 < age < 150

# 使用示例
@validate_input(radius=is_positive, height=is_positive)
def calculate_cylinder_volume(radius, height):
    """计算圆柱体体积"""
    return 3.14159 * radius ** 2 * height

@validate_input(name=is_non_empty_string, age=is_valid_age)
def create_user_profile(name, age):
    """创建用户资料"""
    return f"用户: {name}, 年龄: {age}"

# 测试
if __name__ == "__main__":
    try:
        print(calculate_cylinder_volume(5, 10))
        print(create_user_profile("Alice", 25))
        
        # 测试无效输入
        print(calculate_cylinder_volume(-5, 10))  # 会抛出异常
    except ValueError as e:
        print(f"输入验证错误: {e}")

重试机制 (Retry Mechanism)

  • 作用:当函数执行失败(如遇到网络波动、数据库暂时无法连接)时,自动重试指定的次数。
  • 场景:调用一个第三方 API 接口,但这个接口偶尔不稳定。你可以用一个 @retry(times=3) 装饰器,让它在失败后自动重试 3 次,而不是立即报错。
python
import functools
import time
import random

def retry(max_attempts=3, delay=1, backoff=2, exceptions=(Exception,)):
    """函数执行失败时自动重试的装饰器"""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            current_delay = delay
            
            while attempts < max_attempts:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    attempts += 1
                    if attempts == max_attempts:
                        print(f"重试 {max_attempts} 次后仍然失败: {e}")
                        raise
                    
                    print(f"第 {attempts} 次尝试失败: {e}, {current_delay}秒后重试...")
                    time.sleep(current_delay)
                    current_delay *= backoff  # 指数退避
        return wrapper
    return decorator

# 使用示例
@retry(max_attempts=3, delay=1, exceptions=(ValueError,))
def unreliable_api_call():
    """模拟不可靠的API调用"""
    if random.random() < 0.7:  # 70%的概率失败
        raise ValueError("API调用失败: 连接超时")
    return "API调用成功!"

@retry(max_attempts=5, delay=2, exceptions=(RuntimeError,))
def process_data():
    """模拟数据处理"""
    if random.random() < 0.5:
        raise RuntimeError("数据处理错误")
    return "数据处理完成"

# 测试
if __name__ == "__main__":
    try:
        print(unreliable_api_call())
    except Exception as e:
        print(f"最终失败: {e}")
    
    try:
        print(process_data())
    except Exception as e:
        print(f"最终失败: {e}")

注册函数 (Function Registration)

  • 作用:自动将函数注册到一个中央仓库(如插件系统、路由系统)中,而不需要显式地调用注册代码。
  • 场景:在 Web 框架中,用 @app.route(‘/url’) 装饰器将一个函数注册为处理特定 URL 请求的控制器。框架在启动时会自动收集所有被此装饰器装饰的函数来构建路由映射表。
python
import functools

class PluginRegistry:
    """插件注册表"""
    def __init__(self):
        self.plugins = {}
    
    def register(self, name=None):
        """注册插件的装饰器"""
        def decorator(func):
            plugin_name = name or func.__name__
            self.plugins[plugin_name] = func
            return func
        return decorator
    
    def execute_plugin(self, name, *args, **kwargs):
        """执行指定插件"""
        if name not in self.plugins:
            raise ValueError(f"插件 {name} 未注册")
        return self.plugins[name](*args, **kwargs)
    
    def list_plugins(self):
        """列出所有已注册插件"""
        return list(self.plugins.keys())

# 创建注册表实例
registry = PluginRegistry()

# 使用装饰器注册插件
@registry.register("text_processor")
def process_text(text):
    """处理文本的插件"""
    return text.upper()

@registry.register("math_operation")
def calculate_square(x):
    """计算平方的插件"""
    return x ** 2

@registry.register()  # 使用函数名作为插件名
def custom_format(data):
    """自定义格式化插件"""
    return f"格式化结果: {data}"

# 测试
if __name__ == "__main__":
    print("已注册插件:", registry.list_plugins())
    
    # 执行插件
    print(registry.execute_plugin("text_processor", "hello world"))
    print(registry.execute_plugin("math_operation", 5))
    print(registry.execute_plugin("custom_format", "测试数据"))

缓存 (Memoization)

  • 作用:将函数的结果缓存起来。当用相同的参数再次调用时,直接返回缓存的结果,而不需要重新计算,显著提高计算密集型函数的性能。
  • 场景:计算斐波那契数列。使用 @functools.lru_cache 这个内置装饰器可以避免大量的重复计算。
python
import functools
import time

def cache_results(max_size=128):
    """缓存函数结果的装饰器"""
    def decorator(func):
        cache = {}
        access_order = []  # 记录访问顺序用于LRU策略
        
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # 创建缓存键
            cache_key = (args, frozenset(kwargs.items()) if kwargs else frozenset())
            
            if cache_key in cache:
                print(f"缓存命中: {func.__name__}{args}")
                # 更新访问顺序
                access_order.remove(cache_key)
                access_order.append(cache_key)
                return cache[cache_key]
            
            # 如果缓存满了,删除最久未使用的条目
            if len(cache) >= max_size:
                oldest_key = access_order.pop(0)
                del cache[oldest_key]
            
            result = func(*args, **kwargs)
            cache[cache_key] = result
            access_order.append(cache_key)
            print(f"结果已缓存: {func.__name__}{args}")
            return result
        
        # 提供清理缓存的方法
        def clear_cache():
            cache.clear()
            access_order.clear()
        
        wrapper.clear_cache = clear_cache
        wrapper.cache_info = lambda: {"cache_size": len(cache), "max_size": max_size}
        return wrapper
    return decorator

# 使用示例
@cache_results(max_size=10)
def fibonacci(n):
    """计算斐波那契数列"""
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

@cache_results()
def expensive_calculation(x, y):
    """模拟耗时计算"""
    time.sleep(1)  # 模拟计算耗时
    return x * y + x + y

# 测试
if __name__ == "__main__":
    # 测试斐波那契(会显示缓存效果)
    print("fibonacci(10) =", fibonacci(10))
    print("fibonacci(10) =", fibonacci(10))  # 这次会从缓存获取
    
    # 测试耗时计算
    start = time.time()
    result1 = expensive_calculation(5, 3)
    end = time.time()
    print(f"第一次计算耗时: {end-start:.2f}秒")
    
    start = time.time()
    result2 = expensive_calculation(5, 3)  # 这次会很快
    end = time.time()
    print(f"第二次计算耗时: {end-start:.4f}秒")
    
    print(f"结果: {result1}, {result2}")

事务管理 (Transaction Management)

  • 作用:在函数执行前开始一个数据库事务,如果函数执行成功则提交事务,如果发生异常则回滚事务。
  • 场景:确保数据库操作的原子性。比如一个函数需要更新两张表,要么都成功,要么都失败,用装饰器可以很好地管理这个边界。
python
import functools

class DatabaseTransaction:
    """模拟数据库事务"""
    def __init__(self):
        self.in_transaction = False
        self.changes = []
    
    def begin(self):
        """开始事务"""
        if self.in_transaction:
            raise RuntimeError("事务已在进行中")
        self.in_transaction = True
        self.changes = []
        print("事务开始")
    
    def commit(self):
        """提交事务"""
        if not self.in_transaction:
            raise RuntimeError("没有进行中的事务")
        print(f"事务提交: 应用 {len(self.changes)} 个更改")
        self.in_transaction = False
        self.changes = []
    
    def rollback(self):
        """回滚事务"""
        if not self.in_transaction:
            raise RuntimeError("没有进行中的事务")
        print(f"事务回滚: 撤销 {len(self.changes)} 个更改")
        self.in_transaction = False
        self.changes = []
    
    def execute(self, query, params=None):
        """执行SQL语句"""
        if not self.in_transaction:
            raise RuntimeError("必须在事务中执行操作")
        
        change = f"执行: {query} {params or ''}"
        self.changes.append(change)
        print(change)
        return True

# 创建数据库实例
db = DatabaseTransaction()

def with_transaction(func):
    """管理数据库事务的装饰器"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            db.begin()
            result = func(*args, **kwargs)
            db.commit()
            return result
        except Exception as e:
            db.rollback()
            print(f"事务回滚原因: {e}")
            raise
    return wrapper

# 使用示例
@with_transaction
def transfer_money(from_account, to_account, amount):
    """转账操作"""
    # 模拟数据库操作
    db.execute("UPDATE accounts SET balance = balance - ? WHERE id = ?", (amount, from_account))
    db.execute("UPDATE accounts SET balance = balance + ? WHERE id = ?", (amount, to_account))
    
    # 模拟业务逻辑检查
    if amount > 1000:
        raise ValueError("单笔转账金额不能超过1000")
    
    db.execute("INSERT INTO transactions VALUES (?, ?, ?)", (from_account, to_account, amount))
    return "转账成功"

@with_transaction
def create_user_account(username, email):
    """创建用户账户"""
    db.execute("INSERT INTO users (username, email) VALUES (?, ?)", (username, email))
    db.execute("INSERT INTO accounts (user_id, balance) VALUES (1, 100)")  # 简化ID处理
    return f"用户 {username} 创建成功"

# 测试
if __name__ == "__main__":
    try:
        print(transfer_money(1, 2, 500))
        print("---")
        print(transfer_money(3, 4, 1500))  # 这个会失败并回滚
    except Exception as e:
        print(f"操作失败: {e}")
    
    print("---")
    try:
        print(create_user_account("alice", "alice@example.com"))
    except Exception as e:
        print(f"操作失败: {e}")

总结

装饰器通过一种非常清晰、非侵入式的方式,将核心业务逻辑(函数本身)与横切关注点(辅助性功能,如日志、缓存、认证)分离开,使得代码更加模块化、可读性更强、也更易于维护。 这些示例涵盖了装饰器的各种实际应用场景,每个都可以直接运行并看到效果。你可以根据需要修改和扩展这些装饰器。

Released under the Apache 2.0 License.