## $AcWing$ $1221$. 四平方和 + 自定义排序(重载<)+二分 [题目传送门](https://www.acwing.com/problem/content/description/1223/) ### 一、题目大意 四平方和定理,又称为 **拉格朗日定理**: 每个正整数都可以表示为至多 $4$ 个正整数的平方和。 如果把 $0$ 包括进去,就正好可以表示为 $4$ 个数的平方和。 比如: $5=0^2+0^2+1^2+2^2$ $7=1^2+1^2+1^2+2^2$ 对于一个给定的正整数,可能存在多种平方和的表示法。 要求你对 $4$ 个数排序: $0 \leqslant a \leqslant b \leqslant c \leqslant d$ 并对所有的可能表示法按 $a,b,c,d$ 为联合主键升序排列,最后输出第一个表示法。 **输入格式** 输入一个正整数 $N$。 **输出格式** 输出$4$个非负整数,按 **从小到大** 排序,中间用空格分开。 **数据范围** $0 using namespace std; //通过了 10/11个数据 int main() { int n; cin >> n; for (int a = 0; a * a <= n; a++) for (int b = a; a * a + b * b <= n; b++) for (int c = b; a * a + b * b + c * c <= n; c++) { int t = n - a * a - b * b - c * c; int d = sqrt(t); if (d * d == t) { printf("%d %d %d %d\n", a, b, c, d); return 0; } } return 0; } ``` 暴力 $O(N^3)$ 一般对于$C++$里面的代码,$1s$也就运行$10^8$,这个题的最大取到$10^6$,开方$10$的$3$次方,$N$的$3$次方,大概$10^9$的运算次数,肯定会超时。 ### 三、二分作法 枚举$c$和$d$,将$c^2+d^2$存至数组中,再枚举$a$和$b$,查找$n−a^2−b^2$是否在数组中出现过。时间复杂度:$O(n^2logn^2)$。 ```c++ #include using namespace std; const int N = 5e6 + 10; struct Node { int c, d, sum; bool operator<(const Node &t) const { //对所有的可能表示法按 a,b,c,d 为联合主键升序排列,最后输出第一个表示法。 //因为这个第一个表示法,使得我们首先需要对sum进行升序排列,如果sum值一样,就需要对c进行升序排列.至于说d,其实无所谓了,因为本题中不可能存在 // sum一样,c也一样的场景,但也许有的题不行,需要写全排序办法 if (sum != t.sum) return sum < t.sum; if (c != t.c) return c < t.c; return d < t.d; } } f[N]; int n; int idx; int main() { cin >> n; //枚举c^2+d^2 for (int c = 0; c * c <= n; c++) for (int d = c; c * c + d * d <= n; d++) f[idx++] = {c, d, c * c + d * d}; //结构体排序 sort(f, f + idx); //枚举a^2+b^2 for (int a = 0; a * a <= n; a++) { for (int b = a; a * a + b * b <= n; b++) { int t = n - a * a - b * b; int p = lower_bound(f, f + idx, Node{0, 0, t}) - f; /* 结构体+lower_bound需要注意的事项: 一、结构体中二分查找,需要封装成无用项置0的结构体: (1) 把我们需要查找的数封装成一个结构体。然后才可以在结构体重进行查找。即使我们只需要针对某一维进行查找,也需要把整个结构体构造出来。 (2) 这里我只需要查找第一维,并且我对第一维进行了排序,只有有序数列才可以进行二分,然后在查找的时候,把其他维置零即可。但是必须要封装成一个结构体 二、最终的结果需要进行判断 (1)、可能找到,也可能找不到 (2)、因为lower_bound返回的是数组中第一个大于等于Node{0,0,t}的位置,有三种可能: 1. 命中,找到sum=t,现在p就是结果位置 2. 没有命中,返回的是sum>t的第一个位置,这不是我们想要的结果,需要判断一下这样的情况。 3. 没有命中,返回的是数组外边的一个空地,也是越界了也没有找到大于等于t的第一个位置,这不是我们想要的结果,需要判断一下这样的情况 */ if (p < idx && f[p].sum == t) { cout << a << ' ' << b << ' ' << f[p].c << ' ' << f[p].d << ' ' << endl; return 0; } } } return 0; } ``` ### 四、手写版本二分 ```c++ #include using namespace std; const int N = 5e6 + 10; struct Node { int c, d, sum; bool operator<(const Node &t) const { if (sum != t.sum) return sum < t.sum; if (c != t.c) return c < t.c; return d < t.d; } } f[N]; int n; int idx; int main() { scanf("%d", &n); // C预处理出 c^2+d^2 for (int c = 0; c * c <= n; c++) for (int d = c; c * c + d * d <= n; d++) f[idx++] = {c, d, c * c + d * d}; //结构体排序 sort(f, f + idx); //枚举a^2+b^2 for (int a = 0; a * a <= n; a++) { for (int b = a; a * a + b * b <= n; b++) { int t = n - a * a - b * b; //手写二分模板,左闭右开 // STL二分缺点: // 1、常数较大,速度慢 // 2、对于结构体二分,需要构造空的Struct,还需要有一些玄学的赋零操作,不推荐 //结论:全面采用手写二分办法,忘记STL的二分写法 int l = 0, r = idx; while (l < r) { int mid = (l + r) >> 1; if (f[mid].sum >= t) r = mid; else l = mid + 1; } if (f[l].sum == t) { cout << a << ' ' << b << ' ' << f[l].c << ' ' << f[l].d << ' ' << endl; exit(0); } } } return 0; } ``` ### 五、$STL$的$Hash$表 ```c++ #include using namespace std; typedef pair PII; #define x first #define y second // TLE // 通过了 8/11个数据 unordered_map _map; int main() { int n; cin >> n; for (int c = 0; c * c <= n; c++) for (int d = c; d * d + c * c <= n; d++) { int t = c * c + d * d; if (_map.count(t) == 0) _map[t] = {c, d}; } for (int a = 0; a * a <= n; a++) for (int b = a; b * b + a * a <= n; b++) { int t = n - a * a - b * b; if (_map.count(t)) { printf("%d %d %d %d", a, b, _map[t].x, _map[t].y); return 0; } } return 0; } ``` ### 六、桶排+二层循环+对边寻找 ```c++ #include #include #include #include /* 用桶计数的思路 */ // Accepted 134 ms using namespace std; const int N = 5e6 + 10; int n; int bucket[N]; int main() { scanf("%d", &n); memset(bucket, -1, sizeof bucket); for (int c = 0; c * c <= n / 2; c++) for (int d = c; c * c + d * d <= n; d++) { int t = c * c + d * d; if (t > 5e6) continue; if (bucket[t] == -1) bucket[t] = c; } for (int a = 0; a * a <= n / 4; a++) { for (int b = a; a * a + b * b <= n / 3; b++) { int t = n - a * a - b * b; int c = bucket[t]; if (bucket[t] == -1 || t > 5e6) continue; int d = sqrt(t - c * c); printf("%d %d %d %d\n", a, b, c, d); exit(0); } } return 0; } ```