# Segment Tree Primer Plus
本来在学生成函数,某队友突然发了一道题,只好学学了...
但学了几天也不是很懂,考虑到较难且应用应该不广(?),就草草结束了
参考:
[https://oi-wiki.org/ds/seg-beats/](https://oi-wiki.org/ds/seg-beats/)
[https://www.luogu.com.cn/blog/Sqrtree/solution-p6242](https://www.luogu.com.cn/blog/Sqrtree/solution-p6242)
[ppt](https://pan.baidu.com/s/1PhprDI3UcJ4rUEKZ7E_DgA?pwd=ikun)
提取码:ikun
>HDU5306 Gorgeous Sequence
[https://vjudge.net/problem/HDU-5306](https://vjudge.net/problem/HDU-5306)
区间 $[l,r]$ 中的所有数变成 $min(A_i,x)$
当我们在更新区间最值时分成三种情况。
1. 如果小于等于要求改的最值,就不需要做处理。
2. 如果大于次小值,那么我们就可以根据其与最大值的差值和最大值的个数就可以O(1)的更新答案了,同时打上标记。
3. 如果小于等于次小值,再向左右儿子递归。
有了这三条规则,这样才能保证lazy标记下放时直接更新子区间最大值和区间合的正确性,因为要更改的值大于次大值,所以不可能更小,与父亲区间更新的操作相同。
用势能分析的方法可以证明时间复杂度为 $O(mlogn)$。
```c++
#define linf 0x3f3f3f3f3f3f3f3f
#define N 1000010
int n, m;
ll mx[N << 2], submx[N << 2], cmx[N << 2], lzmn[N << 2], val[N << 2];
inline void pu(int p) {
val[p] = val[ls] + val[rs];
if (mx[ls] == mx[rs]) {
mx[p] = mx[rs];
submx[p] = max(submx[ls], submx[rs]);
cmx[p] = cmx[ls] + cmx[rs];
} else if (mx[ls] > mx[rs]) {
mx[p] = mx[ls];
submx[p] = max(submx[ls], mx[rs]);
cmx[p] = cmx[ls];
} else {
mx[p] = mx[rs];
submx[p] = max(mx[ls], submx[rs]);
cmx[p] = cmx[rs];
}
}
void build(int p, int l, int r) {
mx[p] = submx[p] = lzmn[p] = -linf;
if (l == r) {
val[p] = read();
mx[p] = val[p];
cmx[p] = 1;
return;
}
build(ls, l, mid), build(rs, mid + 1, r);
pu(p);
}
inline void pd(int p, int l, int r) {
if (lzmn[p] != -linf) {
if (mx[ls] > lzmn[p]) {
val[ls] += cmx[ls] * (lzmn[p] - mx[ls]);
lzmn[ls] = mx[ls] = lzmn[p];
}
if (mx[rs] > lzmn[p]) {
val[rs] += cmx[rs] * (lzmn[p] - mx[rs]);
lzmn[rs] = mx[rs] = lzmn[p];
}
lzmn[p] = -linf;
}
}
void updtomn(int p, int l, int r, int x, int y, int k) {
if (mx[p] <= k) return;
if (x <= l && r <= y && submx[p] < k) {
val[p] += cmx[p] * (k - mx[p]);
lzmn[p] = k;
mx[p] = k;
return;
}
pd(p, l, r);
if (x <= mid) updtomn(ls, l, mid, x, y, k);
if (mid < y) updtomn(rs, mid + 1, r, x, y, k);
pu(p);
}
ll inqsum(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return val[p];
}
pd(p, l, r);
ll res = 0;
if (x <= mid) res += inqsum(ls, l, mid, x, y);
if (mid < y) res += inqsum(rs, mid + 1, r, x, y);
return res;
}
ll inqmx(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return mx[p];
}
pd(p, l, r);
ll res = -linf;
if (x <= mid) res = max(res, inqmx(ls, l, mid, x, y));
if (mid < y) res = max(res, inqmx(rs, mid + 1, r, x, y));
return res;
}
signed main() {
int T;
T = read();
while (T--) {
n = read(), m = read();
build(1, 1, n);
for (int i = 1, op, x, y, z; i <= m; ++i) {
op = read(), x = read(), y = read();
if (op == 0) {
z = read();
updtomn(1, 1, n, x, y, z);
} else if (op == 1) {
printf("%lld\n", inqmx(1, 1, n, x, y));
} else {
printf("%lld\n", inqsum(1, 1, n, x, y));
}
}
}
}
```
>BZOJ4695 最假女选手
长度为 n 的序列,支持区间加 x,区间对 x 取 max,区间对 x 取 min,求区间和,求区间最大值,求区间最小值。
```c++
//代码是P4560的
struct node {
int mx, smx, cmx, lzmn, mn, smn, cmn, lzmx, val, lzp;
} tr[N << 2];
inline int read() {
int x = 0;
char c = getchar();
while (c < '0' || c > '9') c = getchar();
while ('0' <= c && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x;
}
inline void pu(int p) {
tr[p].val = tr[ls].val + tr[rs].val;
if (tr[ls].mx == tr[rs].mx) {
tr[p].mx = tr[ls].mx, tr[p].cmx = tr[ls].cmx + tr[rs].cmx;
tr[p].smx = max(tr[ls].smx, tr[rs].smx);
} else if (tr[ls].mx > tr[rs].mx) {
tr[p].mx = tr[ls].mx, tr[p].cmx = tr[ls].cmx;
tr[p].smx = max(tr[ls].smx, tr[rs].mx);
} else {
tr[p].mx = tr[rs].mx, tr[p].cmx = tr[rs].cmx;
tr[p].smx = max(tr[ls].mx, tr[rs].smx);
}
if (tr[ls].mn == tr[rs].mn) {
tr[p].mn = tr[ls].mn, tr[p].cmn = tr[ls].cmn + tr[rs].cmn;
tr[p].smn = min(tr[ls].smn, tr[rs].smn);
} else if (tr[ls].mn < tr[rs].mn) {
tr[p].mn = tr[ls].mn, tr[p].cmn = tr[ls].cmn;
tr[p].smn = min(tr[ls].smn, tr[rs].mn);
} else {
tr[p].mn = tr[rs].mn, tr[p].cmn = tr[rs].cmn;
tr[p].smn = min(tr[ls].mn, tr[rs].smn);
}
}
inline void padd(int p, int l, int r, ll k) {
tr[p].val += k * (r - l + 1);
tr[p].mx += k, tr[p].mn += k;
if (tr[p].smx != -inf) tr[p].smx += k;
if (tr[p].smx != inf) tr[p].smn += k;
if (tr[p].lzmx != -inf) tr[p].lzmx += k;
if (tr[p].lzmn != inf) tr[p].lzmn += k;
tr[p].lzp += k;
}
inline void pmn(int p, int k) {
if (tr[p].mx <= k) return;
tr[p].val += tr[p].cmx * (k - tr[p].mx);
if (tr[p].smn == tr[p].mx) tr[p].smn = k;
if (tr[p].mn == tr[p].mx) tr[p].mn = k;
if (tr[p].lzmx > k) tr[p].lzmx = k;
tr[p].mx = k, tr[p].lzmn = k;
}
inline void pmx(int p, int k) {
if (tr[p].mn >= k) return;
tr[p].val += tr[p].cmn * (k - tr[p].mn);
if (tr[p].smx == tr[p].mn) tr[p].smx = k;
if (tr[p].mx == tr[p].mn) tr[p].mx = k;
if (tr[p].lzmn < k) tr[p].lzmn = k;
tr[p].mn = k, tr[p].lzmx = k;
}
inline void pd(int p, int l, int r) {
if (tr[p].lzp) padd(ls, l, mid, tr[p].lzp), padd(rs, mid + 1, r, tr[p].lzp);
if (tr[p].lzmx != -inf) pmx(ls, tr[p].lzmx), pmx(rs, tr[p].lzmx);
if (tr[p].lzmn != inf) pmn(ls, tr[p].lzmn), pmn(rs, tr[p].lzmn);
tr[p].lzp = 0, tr[p].lzmx = -inf, tr[p].lzmn = inf;
}
inline void build(int p, int l, int r) {
tr[p].lzmn = inf, tr[p].lzmx = -inf, tr[p].lzp = 0;
if (l == r) {
tr[p].val = tr[p].mx = tr[p].mn = 0;
tr[p].smx = -inf, tr[p].smn = inf;
tr[p].cmx = tr[p].cmn = 1;
return;
}
build(ls, l, mid), build(rs, mid + 1, r);
pu(p);
}
inline void add(int p, int l, int r, int x, int y, ll k) {
if (x <= l && r <= y) return padd(p, l, r, k);
pd(p, l, r);
if (x <= mid) add(ls, l, mid, x, y, k);
if (mid < y) add(rs, mid + 1, r, x, y, k);
pu(p);
}
inline void tomn(int p, int l, int r, int x, int y, int k) {
if (tr[p].mx <= k) return;
if (x <= l && r <= y && tr[p].smx < k) return pmn(p, k);
pd(p, l, r);
if (x <= mid) tomn(ls, l, mid, x, y, k);
if (mid < y) tomn(rs, mid + 1, r, x, y, k);
pu(p);
}
inline void tomx(int p, int l, int r, int x, int y, int k) {
if (tr[p].mn >= k) return;
if (x <= l && r <= y && tr[p].smn > k) return pmx(p, k);
pd(p, l, r);
if (x <= mid) tomx(ls, l, mid, x, y, k);
if (mid < y) tomx(rs, mid + 1, r, x, y, k);
pu(p);
}
inline ll inqsum(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) return tr[p].val;
ll res = 0;
if (x <= mid) res += inqsum(ls, l, mid, x, y);
if (mid < y) res += inqsum(rs, mid + 1, r, x, y);
return res;
}
inline ll inqmx(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) return tr[p].mx;
ll res = -inf;
if (x <= mid) res = max(res, inqmx(ls, l, mid, x, y));
if (mid < y) res = max(res, inqmx(rs, mid + 1, r, x, y));
return res;
}
inline ll inqmn(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) return tr[p].mn;
ll res = inf;
if (x <= mid) res = min(res, inqmn(ls, l, mid, x, y));
if (mid < y) res = min(res, inqmn(rs, mid + 1, r, x, y));
return res;
}
inline void print(int p, int l, int r) {
if (l == r) {
cout << tr[p].val << endl;
return;
}
pd(p, l, r);
print(ls, l, mid), print(rs, mid + 1, r);
}
signed main() {
n = read(), m = read();
build(1, 1, n);
for (int i = 1, op, x, y, k; i <= m; ++i) {
op = read(), x = read(), y = read(), k = read();
x++, y++;
if (op == 1) {
tomx(1, 1, n, x, y, k);
} else {
tomn(1, 1, n, x, y, k);
}
}
print(1, 1, n);
}
```
>P6242
给出一个长度为 $n$ 的数列 $A$,同时定义一个辅助数组 $B$,$B$ 开始与 $A$ 完全相同。接下来进行了 $m$ 次操作,操作有五种类型,按以下格式给出:
- `1 l r k`:对于所有的 $i\in[l,r]$,将 $A_i$ 加上 $k$($k$ 可以为负数)。
- `2 l r v`:对于所有的 $i\in[l,r]$,将 $A_i$ 变成 $\min(A_i,v)$。
- `3 l r`:求 $\sum_{i=l}^{r}A_i$。
- `4 l r`:对于所有的 $i\in[l,r]$,求 $A_i$ 的最大值。
- `5 l r`:对于所有的 $i\in[l,r]$,求 $B_i$ 的最大值。
在每一次操作后,我们都进行一次更新,让 $B_i\gets\max(B_i,A_i)$。
以下是需要维护的 tag:
- mxp:该区间最大值的懒标记。
- umxp:该区间非最大值的懒标记。
- mxmxp:该区间最大值的懒标记的最大值。
- mxumxp:该区间非最大的值的懒标记的最大值。
具体地,mxmxp 是在未下传懒标记前最大的 mxp 的值,mxumxp 是在未下传懒标记前最大的 umxp 的值。
```c++
struct node {
ll mx, smx/*submax*/, cmx/*cnt of max*/, hmx/*historical max*/, val;
ll mxp/*addtion of max*/, umxp/*addtion of not max*/, mxmxp/*max of addtion of max*/, mxumxp/*max of addtion of not max*/;
} tr[N << 2];
inline int read() {
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while ('0' <= c && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x * f;
}
inline void pu(int p) {
tr[p].val = tr[ls].val + tr[rs].val;
tr[p].hmx = max(tr[ls].hmx, tr[rs].hmx);
if (tr[ls].mx == tr[rs].mx) {
tr[p].mx = tr[ls].mx, tr[p].cmx = tr[ls].cmx + tr[rs].cmx;
tr[p].smx = max(tr[ls].smx, tr[rs].smx);
} else if (tr[ls].mx > tr[rs].mx) {
tr[p].mx = tr[ls].mx, tr[p].cmx = tr[ls].cmx;
tr[p].smx = max(tr[ls].smx, tr[rs].mx);
} else {
tr[p].mx = tr[rs].mx, tr[p].cmx = tr[rs].cmx;
tr[p].smx = max(tr[ls].mx, tr[rs].smx);
}
}
inline void change(int p, int l, int r, int k1, int k2, int k3, int k4) {
tr[p].val += 1ll * k1 * tr[p].cmx + 1ll * k2 * (r - l + 1 - tr[p].cmx);
tr[p].hmx = max(tr[p].hmx, tr[p].mx + k3);
tr[p].mx += k1;
if (tr[p].smx != -inf) tr[p].smx += k2;
tr[p].mxmxp = max(tr[p].mxmxp, tr[p].mxp + k3);
tr[p].mxumxp = max(tr[p].mxumxp, tr[p].umxp + k4);
tr[p].mxp += k1, tr[p].umxp += k2;
}
inline void pd(int p, int l, int r) {
ll tmpmx = max(tr[ls].mx, tr[rs].mx);
if (tr[ls].mx == tmpmx)change(ls, l, mid, tr[p].mxp, tr[p].umxp, tr[p].mxmxp, tr[p].mxumxp);
else change(ls, l, mid, tr[p].umxp, tr[p].umxp, tr[p].mxumxp, tr[p].mxumxp);
if (tr[rs].mx == tmpmx) change(rs, mid + 1, r, tr[p].mxp, tr[p].umxp, tr[p].mxmxp, tr[p].mxumxp);
else change(rs, mid + 1, r, tr[p].umxp, tr[p].umxp, tr[p].mxumxp, tr[p].mxumxp);
tr[p].mxp = tr[p].umxp = tr[p].mxmxp = tr[p].mxumxp = 0;
}
inline void build(int p, int l, int r) {
if (l == r) {
tr[p].hmx = tr[p].val = tr[p].mx = read();
tr[p].smx = -inf;
tr[p].cmx = 1;
return;
}
build(ls, l, mid), build(rs, mid + 1, r);
pu(p);
}
inline void add(int p, int l, int r, int x, int y, ll k) {
if (x <= l && r <= y) {
tr[p].val += 1ll * k * tr[p].cmx + 1ll * k * (r - l + 1 - tr[p].cmx);
tr[p].mx += k;
tr[p].hmx = max(tr[p].hmx, tr[p].mx);
if (tr[p].smx != -inf)tr[p].smx += k;
tr[p].mxp += k, tr[p].umxp += k;
tr[p].mxmxp = max(tr[p].mxmxp, tr[p].mxp);
tr[p].mxumxp = max(tr[p].mxumxp, tr[p].umxp);
return;
}
pd(p, l, r);
if (x <= mid) add(ls, l, mid, x, y, k);
if (mid < y) add(rs, mid + 1, r, x, y, k);
pu(p);
}
inline void tomn(int p, int l, int r, int x, int y, ll k) {
if (tr[p].mx <= k) return;
if (x <= l && r <= y && tr[p].smx < k) {
ll kk = tr[p].mx - k;
tr[p].val -= 1ll * tr[p].cmx * kk;
tr[p].mx = k, tr[p].mxp -= kk;
return;
}
pd(p, l, r);
if (x <= mid) tomn(ls, l, mid, x, y, k);
if (mid < y) tomn(rs, mid + 1, r, x, y, k);
pu(p);
}
inline ll inqsum(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) return tr[p].val;
pd(p, l, r);
ll res = 0;
if (x <= mid) res += inqsum(ls, l, mid, x, y);
if (mid < y) res += inqsum(rs, mid + 1, r, x, y);
return res;
}
inline ll inqmx(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) return tr[p].mx;
pd(p, l, r);
ll res = -inf;
if (x <= mid) res = max(res, inqmx(ls, l, mid, x, y));
if (mid < y) res = max(res, inqmx(rs, mid + 1, r, x, y));
return res;
}
inline ll inqhmx(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) return tr[p].hmx;
pd(p, l, r);
ll res = -inf;
if (x <= mid) res = max(res, inqhmx(ls, l, mid, x, y));
if (mid < y) res = max(res, inqhmx(rs, mid + 1, r, x, y));
return res;
}
inline void print(int p, int l, int r) {
if (l == r) {
cout << tr[p].hmx << " ";
return;
}
pd(p, l, r);
print(ls, l, mid), print(rs, mid + 1, r);
}
signed main() {
n = read(), m = read();
build(1, 1, n);
for (int i = 1, op, x, y, k; i <= m; ++i) {
op = read(), x = read(), y = read();
if (op == 1) {
k = read();
add(1, 1, n, x, y, k);
} else if (op == 2) {
k = read();
tomn(1, 1, n, x, y, k);
} else if (op == 3) {
printf("%lld\n", inqsum(1, 1, n, x, y));
} else if (op == 4) {
printf("%lld\n", inqmx(1, 1, n, x, y));
} else {
printf("%lld\n", inqhmx(1, 1, n, x, y));
}
}
}
```
>大秦酒店欢迎您
[https://vjudge.csgrandeur.cn/contest/558282#problem/J](https://vjudge.csgrandeur.cn/contest/558282#problem/J)
给定一个长度为 n 的颜色序列,不同的颜色用不同的正整数编号代替。序列中每个正整数编号 ai 代表一种颜色,满足 1 ≤ ai ≤ n。
下面称一个区间的颜色数为这个区间内不同的颜色数量。
有 q 个询问,每次询问给出一个区间 [l, r] ,你要输出 [l, r] 的所有子区间的颜色数之和。
为了防止答案过大,你只需要输出答案对 $2^32$ 取模的结果。
首先把询问按照右端点离线,设当前处理到的位置为x,那么设ai表示区间[i,x]的颜色数,bi表示区间[i,i],[i,i+1]…[i,x]的颜色数之和,容易发现对于右端点为x的询问,其答案就是bi的区间和。
容易发现bi就是ai的历史和,ai的维护就是通过记录“上一次出现同种颜色的位置pre” ,以及简单的区间加即可实现。
对于一个区间,我们要求当前区间和以及所有历史时刻之和
考虑维护 sum,sumh 分别表示当前和以及历史和
如果没有加法标记,我们可以直接存一个 tag,表示 sumh←sumh+tag×sum
然后我们发现存在加法标记的情况下要先下传加法标记,再下传 tag
考虑下传加法标记:
sum←sum+len×v
add←add+v
sumh←sumh+tag×sum
对于点 x 来说,假设原来有 v1,tag1,那么现在变为 v1+v,tag1
而根据定义,tag1 应该只和 v1 结合,所有加多了一部分,那么就再设标记 addh 表示 sumh 需要减去的值就好了.
下传标记的顺序是:add,addh,tag.
```c++
#define ul unsigned int
#define ls p<<1
#define rs p<<1|1
#define mid (l+r>>1)
int n, m;
int pre[N], co[N], last[N];
struct node {
ul a, b;
int lza, lzb, tag;
} tr[N << 2];
struct query {
int l, r, id;
ul ans;
} Q[N];
vector<int> adj[N];
inline void pu(int p) {
tr[p].a = tr[ls].a + tr[rs].a;
tr[p].b = tr[ls].b + tr[rs].b;
}
inline void paddh(int p, int l, int r, int k, int t) {
tr[p].lzb += k;
if (t) tr[p].b += k * (r - l + 1);
}
inline void padd(int p, int l, int r, int k) {
if (tr[p].tag) {
paddh(p, l, r, -k * tr[p].tag, 0);
}
tr[p].lza += k;
tr[p].a += (r - l + 1) * k;
}
inline void ptag(int p, int l, int r, int k) {
tr[p].tag += k;
tr[p].b += k * tr[p].a;
}
inline void pd(int p, int l, int r) {
if (tr[p].lza) {
padd(ls, l, mid, tr[p].lza);
padd(rs, mid + 1, r, tr[p].lza);
tr[p].lza = 0;
}
if (tr[p].lzb) {
paddh(ls, l, mid, tr[p].lzb, 1);
paddh(rs, mid + 1, r, tr[p].lzb, 1);
tr[p].lzb = 0;
}
if (tr[p].tag) {
ptag(ls, l, mid, tr[p].tag);
ptag(rs, mid + 1, r, tr[p].tag);
tr[p].tag = 0;
}
}
inline void upd(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
padd(p, l, r, 1);
return;
}
pd(p, l, r);
if (x <= mid) upd(ls, l, mid, x, y);
if (mid < y) upd(rs, mid + 1, r, x, y);
pu(p);
}
inline ul inq(int p, int l, int r, int x, int y) {
if (x <= l && r <= y) {
return tr[p].b;
}
pd(p, l, r);
ul res = 0;
if (x <= mid) res += inq(ls, l, mid, x, y);
if (mid < y) res += inq(rs, mid + 1, r, x, y);
return res;
}
signed main() {
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
cin >> co[i];
if (last[co[i]]) {
pre[i] = last[co[i]];
}
last[co[i]] = i;
}
for (int i = 1; i <= m; ++i) {
cin >> Q[i].l >> Q[i].r;
adj[Q[i].r].emplace_back(i);
Q[i].id = i;
}
for (int i = 1; i <= n; ++i) {
upd(1, 1, n, pre[i] + 1, i);
ptag(1, 1, n, 1);
for (int j = 0; j < adj[i].size(); ++j) {
query &q = Q[adj[i][j]];
q.ans = inq(1, 1, n, q.l, i);
}
}
for (int i = 1; i <= m; ++i) cout << Q[i].ans << endl;
}
```
>U180387 CTSN loves segment tree
[U180387 CTSN loves segment tree](https://www.luogu.com.cn/problem/U180387)
求区间中 $A_i+B_i$ 的最大值
我们把区间中的 **位置** 分成四类:在 $A,B$ 中同是区间最大值的位置、在 $A$ 中是区间最大值在 $B$ 中不是的位置、在 $B$ 中是区间最大值在 $A$ 中不是的位置、在 $A,B$ 中都不是区间最大值的位置。对这四类数分别维护 **答案** 和 **标记** 即可。举个例子,我们维护 $C_{1\sim 4},M_{1\sim 4},A_{\max},B_{\max}$ 分别表示当前区间中四类数的个数,四类数的答案的最大值,$A$ 序列的最大值、$B$ 序列的最大值。然后合并信息该怎么合并就怎么合并了。