[洛谷]P1009 [NOIP1998 普及组] 阶乘之和

高精度加+高精度乘(优化码风.ver)

Posted by CloudingYu on January 13, 2025

题目描述

用高精度计算出 $S=1!+2!+3!+···+n!$ $($ $n$ $\le$ $50$ $)$ 。

其中 $!$ 表示阶乘,例如:$5!=5\times4\times3\times2\times1$ 。

整体思路

这道题的思路非常简单,就是一个很普通的循环嵌套,而难点在于高精度代码的书写,部分代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <iostream>
#include <cstring> //memset函数头文件
using namespace std;
int n;
int ans[1000], facto[1000];
void mul(int k)
{
    // 这里写高精度乘法
}
void plu()
{
    // 这里写高精度加法
}
int main()
{
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        memset(facto, 0, sizeof(facto)); // 清空数组
        facto[0] = 1;                    // 因为是乘法,初始值为1,位数也为1,
        facto[1] = 1;
        for (int j = 1; j <= i; j++) // 计算阶乘
            mul(j);
        plu();
    }
    for (int i = ans[0]; i >= 1; i++) // 将结果逆序输出
        cout << ans[i];
    cout << endl;
    return 0;
}

我们只需要将两个函数补充完整便可以轻松AC这道题

高精度存储

因为有些计算的位数可达几十位甚至几百位,所以需要使用高精度算法进行计算。

为了方便计算,在使用高精度算法时,通常会使用逆序输入,如:

$123$ 可将其存储在一个长度为4的数组a中,其中$\displaystyle a[1]=3$, $\displaystyle a[2]=2$ ,$\displaystyle a[3]=1$ ,并用 $\displaystyle a[0]=3$ 代表此数的位数是 $3$。

高精度加法

已知有两数 $\displaystyle a[MAXN]$ , $b[MAXN]$ ,求其和$\displaystyle c[MAXN]$ 。 先随便蒟个栗子,找一下规律:

1
2
3
4
5
        4 5 6
          9 3
    +___1_______
        5 4 9
                

可以看到,得到结果的第 $i$ 位,就等于两个加数的第 $i$ 位之和,加上 $i-1$ 位的进位值。 若得到的结果大于等于 $10$ ,则将第 $i+1$ 位的数 $+1$ ,同时自身取模 $10$ 。 则有以下程序:

1
2
3
4
5
6
7
8
9
10
11
for (int i = 1; i <= max(a[0], b[0]); i++)
{
    c[i] = a[i] + b[i];
    c[0] = i;       // 更新位数
    if (c[i] >= 10) // 若二者之和大于等于10,需进位
    {
        c[i] %= 10;   // 取个位数
        c[i + 1]++;   // 进位
        c[0] = i + 1; // 更新位数
    }
}

高精度乘法

相较于加法,高精度乘法更加复杂, 已知有两数 $\displaystyle a [MAXN]$ ,$\displaystyle b [MAXN]$ ,求其积 $\displaystyle c [MAXN]$ 。 还是先找个栗子:

1
2
3
4
5
6
7
        4 5 6
          9 3
    ×_________ 
      1 3 6 8
    4 1 0 4
  ______1_____
    4 2 4 0 8

也可以得出规律 $\displaystyle c[i+j-1]+=a[i]\times b[j]+x$ , 其中$1 \le i \le a[0]$, $\displaystyle 1\le j \le b[0]$ ,$\displaystyle x = c[i+j-1]/10 $ 。 则有以下程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
for (int i = 1; i <= a[0]; i++)
{
    x = 0;
    for (int j = 1; j <= b[0]; j++)
    {
        c[i + j - 1] += a[i] * b[j] + x;
        x = c[i + j - 1] / 10;
        c[i + j - 1] % 10;
    }
    c[i + b[0]] = x;
}
c[0] = a[0] + b[0];
while (c[c[0]] == 0) // 去除多余的0,得到精确的位数
    c[0]--;

接下来,只需要按照两种模版和题目要求依葫芦画瓢就行了。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <iostream>
#include <cstring> //memset函数头文件
using namespace std;
const int MAXN = 1000; // 数组开大点没关系
int n;
int ans[MAXN], facto[MAXN], ans_[MAXN], facto_[MAXN], k_[MAXN];
void mul(int k) // 与上面高精度乘法类似
{
    memset(facto_, 0, sizeof(facto_)); // 清空临时数组
    memset(k_, 0, sizeof(k_));
    while (k != 0) // 将k逆序存储
    {
        k_[++k_[0]] = k % 10;
        k /= 10;
    }
    for (int i = 1; i <= k_[0]; i++)
    {
        int x = 0;
        for (int j = 1; j <= facto[0]; j++)
        {
            facto_[i + j - 1] += k_[i] * facto[j] + x;
            x = facto_[i + j - 1] / 10;
            facto_[i + j - 1] %= 10;
        }
        facto_[i + facto[0]] = x;
    }
    facto_[0] = facto[0] + k_[0];
    while (facto_[facto_[0]] == 0)
        facto_[0]--;
    for (int i = 0; i <= facto_[0]; i++)
        facto[i] = facto_[i];
    return;
}
void plu() // 与上面高精度加法类似
{
    memset(ans_, 0, sizeof(ans_)); // 清空临时数组
    for (int i = 1; i <= max(ans[0], facto[0]); i++)
    {
        ans_[i] += ans[i] + facto[i];
        ans_[0] = i;
        if (ans_[i] >= 10)
        {
            ans_[i] %= 10;
            ans_[i + 1]++;
            ans_[0] = i + 1;
        }
    }
    for (int i = 0; i <= ans_[0]; i++)
        ans[i] = ans_[i];
    return;
}
int main()
{
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        memset(facto, 0, sizeof(facto)); // 清空数组
        facto[0] = 1;                    // 因为是乘法,初始值为1,位数也为1,
        facto[1] = 1;
        for (int j = 1; j <= i; j++) // 计算阶乘
            mul(j);
        plu();
    }
    for (int i = ans[0]; i >= 1; i--) // 将结果逆序输出
        cout << ans[i];
    cout << endl;
    return 0;
}

三年后的我

在大一系统学习了 《程序设计》 之后,我才意识到我高中写的这段代码有多么的丑陋且不具有普遍性,遂重构一份代码,也献给三年前的自己VAV

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <stdio.h> //C++是什么,没学过
#include <stdlib.h>
const int MAXN = 1000;
struct big_digits
{
    int len;
    int digit[MAXN];
};

// 将int型转为big_digits型
struct big_digits *trans(int n)
{
    struct big_digits *ans = (struct big_digits *)malloc(sizeof(struct big_digits));
    ans->len = 0;
    memset(ans->digit, 0, sizeof(ans->digit));
    if (n == 0)
    {
        ans->len = 1;
        ans->digit[0] = 0;
    }
    else
    {
        while (n)
        {
            ans->digit[ans->len++] = n % 10;
            n /= 10;
        }
    }
    return ans;
}

// 返回a+b的结果
struct big_digits *plus(struct big_digits *a, struct big_digits *b)
{
    struct big_digits *ans = (struct big_digits *)malloc(sizeof(struct big_digits));
    memset(ans->digit, 0, sizeof(ans->digit));
    int carry = 0;
    for (int i = 0; i < MAXN; i++)
    {
        ans->digit[i] = a->digit[i] + b->digit[i] + carry;
        carry = ans->digit[i] / 10;
        ans->digit[i] %= 10;
    }
    ans->len = MAXN;
    while (ans->len > 1 && ans->digit[ans->len - 1] == 0)
        ans->len--;
    return ans;
}

// 返回a*b的结果
struct big_digits *mult(struct big_digits *a, struct big_digits *b)
{
    struct big_digits *ans = (struct big_digits *)malloc(sizeof(struct big_digits));
    memset(ans->digit, 0, sizeof(ans->digit));
    for (int i = 0; i < a->len; i++)
    {
        for (int j = 0; j < b->len; j++)
        {
            ans->digit[i + j] += a->digit[i] * b->digit[j];
            ans->digit[i + j + 1] += ans->digit[i + j] / 10;
            ans->digit[i + j] %= 10;
        }
    }
    ans->len = a->len + b->len;
    while (ans->len > 1 && ans->digit[ans->len - 1] == 0)
        ans->len--;
    return ans;
}

// 输出big_digits型
void prin(struct big_digits *a)
{
    for (int i = a->len - 1; i >= 0; i--)
    {
        printf("%d", a->digit[i]);
    }
    printf("\n");
}
int main()
{
    int n;
    scanf("%d", &n);
    struct big_digits *ans = (struct big_digits *)malloc(sizeof(struct big_digits));
    ans = trans(0);
    for (int i = 1; i <= n; i++)
    {
        struct big_digits *temp = (struct big_digits *)malloc(sizeof(struct big_digits));
        temp = trans(1);
        for (int j = 1; j <= i; j++)
            temp = mult(temp, trans(j));
        ans = plus(ans, temp);
    }
    prin(ans);
    return 0;
}