You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

8.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

P3372 【模板】线段树 1

题目描述

如题,已知一个数列,你需要进行下面两种操作:

  1. 将某区间每一个数加上 k
  2. 求出某区间每一个数的和。

输入格式

第一行包含两个整数 n, m,分别表示该数列数字的个数和操作的总个数。

第二行包含 n 个用空格分隔的整数,其中第 i 个数字表示数列第 i 项的初始值。

接下来 m 行每行包含 34 个整数,表示一个操作,具体如下:

  1. 1 x y k:将区间 [x, y] 内每个数加上 k
  2. 2 x y:输出区间 [x, y] 内每个数的和。

输出格式

输出包含若干行整数,即为所有操作 2 的结果。

样例 #1

样例输入 #1

5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4

样例输出 #1

11
8
20

提示

对于 30\% 的数据:n \le 8m \le 10
对于 70\% 的数据:n \le {10}^3m \le {10}^4
对于 100\% 的数据:1 \le n, m \le {10}^5

保证任意时刻数列中所有元素的绝对值之和 \le {10}^{18}

【样例解释】

二、线段树解法

#include <iostream>
using namespace std;
const int N = 100010;
int n, q;

// 线段树模板
#define int long long
#define ls u << 1
#define rs u << 1 | 1
#define mid ((l + r) >> 1)
struct Node {
    int l, r;
    int sum, add; // 区间总和,累加懒标记
} tr[N << 2];

// 更新统计信息
void pushup(int u) {
    tr[u].sum = tr[ls].sum + tr[rs].sum;
}

void pushdown(int u) {
    if (tr[u].add == 0) return;
    tr[ls].add += tr[u].add;
    tr[rs].add += tr[u].add;
    tr[ls].sum += (tr[ls].r - tr[ls].l + 1) * tr[u].add;
    tr[rs].sum += (tr[rs].r - tr[rs].l + 1) * tr[u].add;
    tr[u].add = 0; // 清除懒标记
}

// 构建
void build(int u, int l, int r) {
    tr[u].l = l, tr[u].r = r; // 标记范围
    if (l == r) {             // 叶子
        cin >> tr[u].sum;     // 区间内只有一个元素l(r),区间和为read(),不需要记录向下的传递tag
        return;
    }
    build(ls, l, mid), build(rs, mid + 1, r); // 左右儿子构建
    pushup(u);                                // 通过左右儿子构建后,向祖先节点反馈统计信息变化
}

// 区间所有元素加上v
void modify(int u, int L, int R, int v) {
    int l = tr[u].l, r = tr[u].r;
    if (l >= L && r <= R) { // 如果完整被覆盖
        tr[u].sum += (r - l + 1) * v;
        tr[u].add += v;
        return;
    }
    if (l > R || r < L) return;               // 如果没有交集
    pushdown(u);                              // 下传懒标记
    modify(ls, L, R, v), modify(rs, L, R, v); // 修改左,修改右
    pushup(u);                                // 向上汇报统计信息
}

// 查询
int query(int u, int L, int R) {
    int l = tr[u].l, r = tr[u].r;
    if (l >= L && r <= R) return tr[u].sum;   // 如果完整被覆盖
    if (l > R || r < L) return 0;             // 如果没有交集
    pushdown(u);                              // 下传懒标记
    return query(ls, L, R) + query(rs, L, R); // 查询左+查询右
}

signed main() {
// 文件输入输出
#ifndef ONLINE_JUDGE
    freopen("P3372.in", "r", stdin);
#endif
    // 加快读入
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> q;

    // 构建线段树
    build(1, 1, n);

    while (q--) {
        int op, l, r, v;
        cin >> op >> l >> r;
        if (op == 1)
            cin >> v, modify(1, l, r, v);
        else
            printf("%lld\n", query(1, l, r));
    }
    return 0;
}

:加法的懒标记可以叠加,一般初始化为0

三、动态开点线段树解法

#include <bits/stdc++.h>

using namespace std;
const int N = 5e5 + 10;

// 动态开点线段树
#define int long long
#define ls tr[u].l
#define rs tr[u].r
#define mid ((l + r) >> 1)
struct Node {
    int l, r;
    int sum, add;
} tr[N << 1];

int root, idx;
// 汇总统计信息
void pushup(int u) {
    tr[u].sum = tr[ls].sum + tr[rs].sum;
}
// 创建节点:节点号分配,懒标记初始化
void build(int &u) {
    if (u) return;
    u = ++idx;
    // tr[u].add = 0;
}

void pushdown(int &u, int l, int r) {
    if (tr[u].add == 0) return; // 如果没有累加懒标记,返回
    build(ls);                  // 左儿子创建
    build(rs);                  // 右儿子创建
    // 懒标记下传
    tr[ls].sum += tr[u].add * (mid - l + 1); // 区间和增加=懒标记 乘以 区间长度
    tr[rs].sum += tr[u].add * (r - mid);
    tr[ls].add += tr[u].add; // 加法的懒标记可以叠加
    tr[rs].add += tr[u].add;
    // 清除懒标记
    tr[u].add = 0;
}

// 区间修改
void modify(int &u, int l, int r, int L, int R, int v) {
    build(u); // 动态开点

    if (l >= L && r <= R) {           // 如果区间被完整覆盖
        tr[u].add += v;               // 加法的懒标记可以叠加
        tr[u].sum += v * (r - l + 1); // 区间和增加=懒标记 乘以 区间长度
        return;
    }
    if (l > R || r < L) return; // 如果没有交集

    // 下传懒标记
    pushdown(u, l, r);
    // 分裂
    modify(ls, l, mid, L, R, v), modify(rs, mid + 1, r, L, R, v);
    // 汇总
    pushup(u);
}

// 区间查询
int query(int u, int l, int r, int L, int R) {
    if (l >= L && r <= R) return tr[u].sum; // 如果完整命中,返回我的全部
    if (l > R || r < L) return 0;           // 如果与我无关,返回0
    pushdown(u, l, r);
    return query(ls, l, mid, L, R) + query(rs, mid + 1, r, L, R);
}
/*
参考答案:
11
8
20
*/
signed main() {
#ifndef ONLINE_JUDGE
    freopen("P3372.in", "r", stdin);
#endif
    // 加快读入
    ios::sync_with_stdio(false), cin.tie(0);
    int n, m;
    cin >> n >> m; // n个节点m次操作
    for (int i = 1; i <= n; i++) {
        int x;
        cin >> x;
        modify(root, 1, n, i, i, x); // 单点修改,赋初值
    }

    while (m--) {
        int op, l, r;
        cin >> op >> l >> r;
        if (op == 1) {
            int x;
            cin >> x;
            modify(root, 1, n, l, r, x); //[l,r]区间修改为x
        } else
            cout << query(root, 1, n, l, r) << endl; // 区间sum和
    }
    return 0;
}

四、树状数组实现【不推荐】

区间修改和区间查询,正解还是线段树,不应该是树状数组+推公式,非得要做的话,也可以推导一下:

区间修改,单点查询

如果是区间修改,单点查询。只需用树状数组维护一个差分数组b,假设查询位置x,那么\displaystyle \sum_{i=1}^{x}b_i就是x位置上的变化后的值。

区间修改+区间和查询

考虑引入区间查询。首先最暴力想,假设查询[1,r]。那么[1,r]的答案=\displaystyle \sum_{i=1}^{r}\sum_{j=1}^{i}b_j 不妨举个特例,更直观些。假设查询[1, 4]。那么ans=(b_1)+(b_1+b_2)+(b_1+b_2+b_3)+(b_1+b_2+b_3+b_4)=4b_1+3b_2+2b_3+1b_4

换成查询[1, r]。那么\displaystyle ans=(r+1-1)b_1+(r+1-2)b_2+(r+1-3)b_3+…+(r+1-r)b_r = (r+1)\sum_{i=1}^{r}b_i-\sum_{i=1}^{r}i*b_i

显然第一项用树状数组tr1维护b数组可求出,第二项求不出。令c=i*b[i],新开一个树状数组tr2维护c就行了。

实现代码

#include <bits/stdc++.h>
using namespace std;
const int N = 1000010;
typedef long long LL;

int n, m; // n个元素m次操作
int a[N]; // 原始数组

LL tr1[N], tr2[N]; // ① 保存基底数组为原数组差分数组的树状数组 ② i*b[i]的前缀和数组

// 树状数组模板
int lowbit(int x) {
    return x & -x;
}

void add(int x, int c) {
    for (int i = x; i < N; i += lowbit(i)) tr1[i] += c, tr2[i] += x * c;
}

LL sum(int x) {
    LL res = 0;
    for (int i = x; i; i -= lowbit(i)) res += (x + 1) * tr1[i] - tr2[i];
    return res;
}

int main() {
    scanf("%d %d", &n, &m);

    int x, y, d, op;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        add(i, a[i] - a[i - 1]); // 保存基底是差分数组的树状数组
    }

    while (m--) {
        scanf("%d %d %d", &op, &x, &y);
        if (op == 1) {
            scanf("%d", &d);
            add(x, d), add(y + 1, -d); // 维护差分
        } else                         // 查询
            printf("%lld\n", sum(y) - sum(x - 1));
    }
    return 0;
}