如何创建支持对称矩阵特性的 NumPy 子类

本文介绍如何通过子类化 `numpy.ndarray` 实现一个轻量、安全的对称矩阵类,自动强制对称性,并在赋值时保持结构不变;同时建议利用 `np.linalg.eigh` 而非缓存 `u` 和 `d` 属性,以兼顾正确性与内存效率。

要构建专用于对称矩阵的 NumPy 子类,核心在于两点:构造时自动对称化输入,以及赋值时同步更新对称位置。直接继承 np.ndarray 并重写 __new__ 与 __setitem__ 是最简洁可靠的方案(避免使用已弃用的 __array_finalize__ 或复杂钩子)。

以下是一个生产就绪的 SymmetricArray 实现:

import numpy as np

class SymmetricArray(np.ndarray):
    def __new__(cls, input_array):
        # 强制最后一维为方阵(支持批量张量,如 (N, D, D))
        assert input_array.ndim >= 2 and input_array.shape[-1] == input_array.shape[-2], \
            "Last two dimensions must be equal for symmetry"

        # 计算对称部分:(A + A^T) / 2
        axes = list(range(input_array.ndim - 2)) + [-1, -2]
        transposed = input_array.transpose(axes)
        sym_arr = 0.5 * (input_array + transposed)
        return sym_arr.view(cls)

    def __setitem__(self, key, value):
        # 标准化索引为 tuple,补全省略的维度(如 a[1] → a[1, :])
        if not isinstance(key, tuple):
            key = (key,)
        if len(key) < self.ndim:
            key += (slice(None),) * (self.ndim - len(key))

        # 构造对称索引:交换最后两个轴的下标
        key_t = key[:-2] + (key[-1], key[-2])

        # 确保 value 也对称化(尤其当 value 是矩阵时)
        value = np.asarray(value)
        if value.ndim >= 2 and value.shape[-1] == value.shape[-2]:
            axes_v = list(range(value.ndim - 2)) + [-1, -2]
            value_t = value.transpose(axes_v)
        else:
            value_t = value  # 标量或向量无需转置

        # 同步写入原位置与对称位置
        super().__setitem__(key, value)
        super().__setitem__(key_t, value_t)

关键特性说明

  • ✅ 支持多维广播:如 (5, 4, 4) 批量对称矩阵,仅对最后两维施加对称约束;
  • ✅ 安全索引:a[:, 1] = 0 会自动设 a[1, :] = 0,保持对称;
  • ✅ 兼容 NumPy 通用操作:切片、广播、ufunc 均可直接使用;
  • ✅ 不冗余存储 U/D:强烈推荐按需调用 np.linalg.eigh() ——
    S = SymmetricArray([[2, 1], [1, 3]])
    D, U = np.linalg.eigh(S)  # 正确、稳定、支持实对称矩阵专属算法

⚠️ 注意事项

  • ❌ 避免在 __init__ 中赋值(ndarray 子类初始化逻辑在 __new__ 中完成);
  • ❌ 不要缓存 U/D 为实例属性:矩阵内容可能被后续 __setitem__ 修改,导致特征分解过期且难以维护一致性;
  • ⚠️ 若需频繁访问特征值,可封装为只读属性(但内部仍每次计算):
    @property
    def eigenvalues(self):
        return np.linalg.eigh(self)[0]

该设计平衡了简洁性、健壮性与 NumPy 生态兼容性,是构建领域专用数组的典型范式。