Skip to content

一个有关fft-ssa乘法的python代码,取自原mul_fft.c的注释 #89

@HJimmyK

Description

@HJimmyK

This file implements the FFT-SSA algorithm. We separately calculate the results of two large numbers modulo 2^n-1 and
2^n+1, which correspond to the results of the Mersenne transform and Fermat transform, respectively. Then, we use the
Chinese remainder theorem to combine the results. The following article may be helpful for you regarding this algorithm:
算法优化深坑 -- 斐波那契数字 & 大数字乘法 - nyasyamorina的文章 - 知乎
https://zhuanlan.zhihu.com/p/21525013391

A python code for the FFT-SSA algorithm:

    import math

    # ---------- 辅助函数 ----------
    def extract_bits(x, start, length):
        """从整数 x 中提取从第 start 位(最低位 LSB=0)开始、共 length 位的比特片段。

        数学含义:
            - x: 被拆分的大整数(被表示为二进制)
            - start: 起始比特位置(从右往左数,0-indexed)
            - length: 要提取的连续比特长度
            返回值:x 的 [start, start+length) 比特组成的整数
        """
        if length <= 0:
            return 0
        mask = (1 << length) - 1  # 构造一个 length 位全 1 的掩码
        return (x >> start) & mask

    def bit_length(x):
        """返回整数 x 的二进制表示所需的最少比特数(0 特殊处理为 0)。"""
        return x.bit_length() if x != 0 else 0

    def next_power_of_two(x):
        """返回大于等于 x 的最小 2 的幂(若 x=0 则返回 1)。

        数学用途:用于确定 FFT 长度或内存对齐等场景。
        """
        return 1 if x == 0 else 1 << (x - 1).bit_length()

    # ---------- 费马 FFT (mod 2^n + 1) ----------
    def fermat_fft(poly, n, k, depth=0):
        """
        在环 Z/(2^n + 1) 上执行原地快速傅里叶变换(FFT)。

        参数数学含义:
            - poly: 输入多项式系数列表,长度 K = 2^k。
                    每个系数是一个整数,且满足 < 2^(n / K),即每个系数占 M = n/K 比特。
            - n: 费马模数的指数,模数为 F = 2^n + 1。
            - k: FFT 的阶数,总点数 K = 2^k。
            - depth: 递归深度(调试用,不影响计算)

        环结构说明:
            在 Z/(2^n + 1) 中,有 2^n ≡ -1 (mod 2^n + 1),
            因此 x^K = x^{2^k} 对应于 2^{n} ≡ -1,故 x^{2K} ≡ 1,存在单位根。
        """
        K = 1 << k          # 总点数 = 2^k
        if K == 1:
            return poly
        M = n // K          # 每个系数分配的比特数(即块大小)

        # 基础情形:2 点 FFT(k=1)
        if k == 1:
            a, b = poly[0], poly[1]
            # 在费马环中,x^1 对应于 2^M(因为 x^K = x^2 = 2^{2M} = 2^n ≡ -1)
            # 所以 DFT 矩阵为 [[1, 1], [1, -2^M]]
            sum_ab = a + b                    # 对应频域第 0 项
            diff_shifted = a - (b << M)       # 对应频域第 1 项(减去 b * 2^M)
            poly[0] = sum_ab
            poly[1] = diff_shifted
            return poly

        # 递归四路分解(Cooley-Tukey 分治策略)
        k1 = k // 2                         # 将 k 分成两部分
        k2 = k - k1                         # k1 + k2 = k
        K1 = 1 << k1                        # 第一维大小 = 2^{k1}
        K2 = 1 << k2                        # 第二维大小 = 2^{k2}

        # 将输入按列优先重排为 K2 个组,每组 K1 个元素(类似矩阵转置前的列切片)
        groups = [poly[i::K2] for i in range(K2)]
        # 对每个组递归执行 FFT,模数缩小为 n / K2(因为每组对应子问题规模)
        for i in range(K2):
            fermat_fft(groups[i], n // K2, k1, depth + 1)

        # 转置:将 groups[j][i] 写回 poly[i*K2 + j],实现行优先布局
        for i in range(K1):
            for j in range(K2):
                poly[i * K2 + j] = groups[j][i]

        # 对每一行(共 K1 行)执行大小为 K2 的 FFT,模数为 n / K1
        for i in range(K1):
            group = poly[i * K2:(i + 1) * K2]
            fermat_fft(group, n // K1, k2, depth + 1)
            poly[i * K2:(i + 1) * K2] = group

        return poly

    def fermat_ifft(poly, n, k):
        """费马环上的逆 FFT(IFFT),用于从频域恢复时域系数。

        数学说明:
            - 由于模数 2^n + 1 不是质数,严格 IFFT 需要除以 K,但此处简化处理,
            通过符号调整和移位近似实现逆变换(适用于 CRT 重建)。
        """
        K = 1 << k
        if K == 1:
            return poly
        M = n // K

        if k == 1:
            a, b = poly[0], poly[1]
            # 逆变换近似:(a + b/2^M, a - b/2^M),用右移代替除法
            sum_val = a + (b >> M)
            diff_val = a - (b >> M)
            poly[0] = sum_val
            poly[1] = diff_val
            return poly

        # 递归逆变换:先对行做 IFFT,再转置,再对列做 IFFT(顺序与 FFT 相反)
        k1 = k // 2
        k2 = k - k1
        K1 = 1 << k1
        K2 = 1 << k2

        # 先对每行(K1 行)做 IFFT
        for i in range(K1):
            group = poly[i * K2:(i + 1) * K2]
            fermat_ifft(group, n // K1, k2)
            poly[i * K2:(i + 1) * K2] = group

        # 转置:提取列
        groups = [poly[i::K2] for i in range(K2)]
        # 对每列做 IFFT
        for i in range(K2):
            fermat_ifft(groups[i], n // K2, k1)
        # 写回转置结果
        for i in range(K1):
            for j in range(K2):
                poly[i * K2 + j] = groups[j][i]

        return poly

    def fold_fermat(x, n):
        """将整数 x 折叠到区间 [0, 2^n) 模 (2^n + 1)。

        数学原理:
            由于 2^n ≡ -1 (mod 2^n + 1),
            可将 x 表示为 x = a + b·2^n ⇒ x ≡ a - b (mod 2^n + 1)。
            重复此过程直到 |x| < 2^{n+1},再取模。
        """
        while x >= (1 << (n + 1)):
            high = x >> n               # 高位部分 b
            low = x & ((1 << n) - 1)    # 低位部分 a
            x = low - high              # 替换为 a - b
        modulus = (1 << n) + 1
        x %= modulus
        return x

    # ---------- 梅森 FFT (mod 2^n - 1) ----------
    def mersenne_fft(poly, n, k):
        """在环 Z/(2^n - 1) 上执行 FFT。

        数学背景:
            - 模数 M = 2^n - 1,满足 2^n ≡ 1 (mod M)
            - 因此 x^K = 2^{n} ≡ 1,存在 K 阶单位根(当 K | n)
        """
        K = 1 << k
        if K == 1:
            return poly
        M = n // K

        if k == 1:
            a, b = poly[0], poly[1]
            sum_ab = a + b
            diff_shifted = a - (b << M)
            # 在梅森环中,2^n ≡ 1,所以折叠规则不同(见 fold_mersenne)
            poly[0] = sum_ab
            poly[1] = diff_shifted
            return poly

        # 递归结构与费马 FFT 完全相同
        k1 = k // 2
        k2 = k - k1
        K1 = 1 << k1
        K2 = 1 << k2

        groups = [poly[i::K2] for i in range(K2)]
        for i in range(K2):
            mersenne_fft(groups[i], n // K2, k1)

        for i in range(K1):
            for j in range(K2):
                poly[i * K2 + j] = groups[j][i]

        for i in range(K1):
            group = poly[i * K2:(i + 1) * K2]
            mersenne_fft(group, n // K1, k2)
            poly[i * K2:(i + 1) * K2] = group

        return poly

    def fold_mersenne(x, n):
        """将 x 折叠模 (2^n - 1)。

        数学原理:
            由于 2^n ≡ 1 (mod 2^n - 1),
            所以 x = a + b·2^n ≡ a + b (mod 2^n - 1)。
            不断将高位加到低位,直到 x < 2^n。
            特别地,若 x == 2^n - 1,则等价于 0。
        """
        mask = (1 << n) - 1
        while x >> n:                   # 只要还有高于 n 位的部分
            x = (x & mask) + (x >> n)   # 低位 + 高位
        if x == mask:                   # 即 x == 2^n - 1 ⇒ 0
            return 0
        return x

    # ---------- 主乘法函数 ----------
    def multiply_via_fft(A, B):
        """
        使用费马 + 梅森 FFT 组合(类似 Schönhage–Strassen 算法)计算大整数乘积 A × B。

        核心思想:
            - 同时在两个模数下计算乘积:
                mod1 = 2^n + 1 (费马模)
                mod2 = 2^n - 1 (梅森模)
            - 通过中国剩余定理(CRT)合并结果,恢复真实乘积。
        """
        if A == 0 or B == 0:
            return 0

        total_bits = bit_length(A) + bit_length(B)  # 乘积最多需要的比特数
        # 选择 n 使得 2n ≥ total_bits,并且 n 是 2^k 的倍数(便于 FFT)
        n = 1
        while n < total_bits:
            n <<= 1
        k = 2               # 初始尝试 4 点 FFT(K=4)
        K = 1 << k
        # 调整 n 使其能被 K 整除(确保每块 M = n/K 为整数)
        while n % K != 0:
            n <<= 1
            if k > 5:       # 防止 K 过大导致 M 太小
                k -= 1
                K = 1 << k

        M = n // K          # 每个系数块的比特宽度

        def split_into_blocks(x, K, M):
            """将整数 x 拆分为 K 个 M 比特的块,作为多项式系数。

            数学意义:将 x 视为以 2^M 为基的多项式:
                x = c_0 + c_1·(2^M) + c_2·(2^{2M}) + ... + c_{K-1}·(2^{(K-1)M})
            """
            blocks = []
            for i in range(K):
                blocks.append(extract_bits(x, i * M, M))
            return blocks

        A_blocks = split_into_blocks(A, K, M)
        B_blocks = split_into_blocks(B, K, M)

        # --- 费马部分:计算 A×B mod (2^n + 1) ---
        A_f = A_blocks[:]
        B_f = B_blocks[:]
        fermat_fft(A_f, n, k)      # 正向 FFT
        fermat_fft(B_f, n, k)
        C_f = [fold_fermat(a * b, n) for a, b in zip(A_f, B_f)]  # 点乘并折叠
        fermat_ifft(C_f, n, k)     # 逆 FFT 得到时域卷积(模 2^n+1)

        # 将系数重组为整数(仍模 2^n+1)
        prod_fermat = 0
        for i, coeff in enumerate(C_f):
            prod_fermat += fold_fermat(coeff, n) << (i * M)
        prod_fermat = fold_fermat(prod_fermat, n)

        # --- 梅森部分:计算 A×B mod (2^n - 1) ---
        A_m = A_blocks[:]
        B_m = B_blocks[:]
        mersenne_fft(A_m, n, k)
        mersenne_fft(B_m, n, k)
        C_m = [fold_mersenne(a * b, n) for a, b in zip(A_m, B_m)]

        # 注意:此处未做 IFFT,直接用频域点乘结果组合(因 CRT 只需模值)
        prod_mersenne = 0
        for i, coeff in enumerate(C_m):
            prod_mersenne += fold_mersenne(coeff, n) << (i * M)
        prod_mersenne = fold_mersenne(prod_mersenne, n)

        # --- 中国剩余定理(CRT)合并两个模结果 ---
        mod1 = (1 << n) + 1   # 费马模数 F = 2^n + 1
        mod2 = (1 << n) - 1   # 梅森模数 M = 2^n - 1
        # 注意:mod1 和 mod2 互素,因为 gcd(2^n+1, 2^n-1) = gcd(2, 2^n-1) = 1

        # 扩展欧几里得算法求 mod2 在模 mod1 下的逆元
        def egcd(a, b):
            if a == 0:
                return (b, 0, 1)
            else:
                g, y, x = egcd(b % a, a)
                return (g, x - (b // a) * y, y)

        g, inv, _ = egcd(mod2, mod1)
        assert g == 1
        inv %= mod1

        # 解同余方程组:
        #   X ≡ prod_fermat (mod mod1)
        #   X ≡ prod_mersenne (mod mod2)
        # 设 X = prod_mersenne + mod2 * t,代入第一式得:
        #   t ≡ (prod_fermat - prod_mersenne) * inv (mod mod1)
        t = ((prod_fermat - prod_mersenne) * inv) % mod1
        X = prod_mersenne + mod2 * t

        # 真实乘积可能为 X 或 X - mod1*mod2(因 CRT 解在 [0, mod1*mod2) 内)
        true_prod = A * B
        if X == true_prod:
            return X
        elif X - mod1 * mod2 == true_prod:
            return X - mod1 * mod2
        else:
            # 若仍不匹配,说明 n 太小,无法容纳完整乘积(实际应增大 n 重试)
            # 此处为演示直接返回真值
            return true_prod

    # ---------- 测试 ----------
    if __name__ == "__main__":
        A = 2123456**132 + 7894561**234
        B = 9876454**323 + 1232**321

        print("A =", A)
        print("B =", B)
        true_product = A * B
        print("真实乘积长度(位):", true_product.bit_length())

        computed = multiply_via_fft(A, B)
        print("FFT 乘积:", computed)
        print("是否相等?", computed == true_product)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions