说实话这题写树剖 $LCT$ 什么的真的思想又不难又好实现的样子,但是我还是选择自虐选择了动态点分治

那就两种做法都稍微提一下:

树链剖分 / $LCT$

很容易可以发现一个换根操作只会对当前根在原树(根为 $1$ )上的祖先一条链造成影响,也就是将它们的子树变成除当前链方向其它与之相连的点集,那么用树剖跳,用线段树维护一下原树上从上面来的和从下面来的,再将所有涉及的节点合并,并且删去算重复的和不该算的即可(虽然我没实现但是这个思路应该是对的

动态点分治

首先有两篇博客:zzq租酥雨

因为我们要算 $\sum\limits_{i = 1}^n s_i^2$ ,可以先想一下 $\sum\limits_{i = 1}^n s_i$ 怎么算,因为每个点的贡献只会被祖先计算到,那么易知 $\sum\limits_{i = 1}^n s_i = \sum\limits_{i = 1}^n value_i * (depth_i + 1)$ ,这个直接用动态点分治维护三个变量 $sumo_i, sumt_i, sumfa_i$ (sumo -> $p$ 子节点权值之和, $sumt$ -> 子节点权值与距离的乘积到 $p$ 之和, $sumfa$ -> 子节点权值与距离的乘积到 $fa$ (点分树上)之和)即可得到

接下来有个很重要的结论(反正我是肯定想不到

​ - 不论根如何换, $\sum\limits_{i = 1}^n s_i (sum - s_i)$ 一定是一个定值(说实话一开始一直想着如何化简 $\sum\limits_{i = 1}^n s_i$ ,所以是真的没有想到可以通过构造定值的方法来解出 $\sum\limits_{i = 1}^n s_i^2$ )

先来意会一下这个结论:就是每一条边连接的两个点在他们的子树中各自选两个点让它们权值相乘,求总权值

那么就比较容易知道证明了:每条边的边权为所有对应路径经过这条边的两个点的权值和,求总边权,即求的是 $\sum\limits_{i = 1}^n \sum\limits_{j = 1}^n value_i * value_j * dist (i, j)$ ,所以有 $\sum\limits_{i = 1}^n s_i (sum - s_i) = \sum\limits_{i = 1}^n \sum\limits_{j = 1}^n value_i * value_j * dist (i, j)$

因为该式是个定值,所以求出 $\sum\limits_{i = 1}^n s_i$ 后直接解出 $Ans$ 就完成了查询操作

对于修改操作,若修改完后与原权值的差值为 $\Delta value$ ,那么 $\Delta total = \Delta value \sum\limits_{j = 1}^n value_j * dist (p, j)$ ( $p$ 为修改点)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long LL;

const int MAXN = 2e05 + 10;
const int MAXM = 2e05 + 10;

const int INF = 0x7fffffff;

struct LinkedForwardStar {
int to;

int next;
} ;

LinkedForwardStar Link[MAXM << 1];
int Head[MAXN]= {0};
int size = 0;

void Insert (int u, int v) {
Link[++ size].to = v;
Link[size].next = Head[u];

Head[u] = size;
}

int N, Q;
int value[MAXN];

LL Deep[MAXN]= {0};
int Dfn[MAXN]= {0};
int val[MAXN << 1]= {0}, belong[MAXN << 1]= {0};
int dfsord = 0;
LL sum = 0;
LL subtree[MAXN]= {0};
LL s1 = 0;
void DFS (int root, int fa) {
Dfn[root] = ++ dfsord;
val[dfsord] = Deep[root], belong[dfsord] = root;
sum += value[root];
subtree[root] = value[root];
for (int i = Head[root]; i; i = Link[i].next) {
int v = Link[i].to;
if (v == fa)
continue;
Deep[v] = Deep[root] + 1;
DFS (v, root);
val[++ dfsord] = Deep[root], belong[dfsord] = root;
subtree[root] += subtree[v];
}
}
pair<int, int> ST[MAXN << 1][25];
void RMQ () {
for (int i = 1; i <= dfsord; i ++)
ST[i][0] = make_pair (val[i], belong[i]);
for (int j = 1; j <= 20; j ++)
for (int i = 1; i + (1 << j) - 1 <= dfsord; i ++)
ST[i][j] = ST[i][j - 1].first < ST[i + (1 << (j - 1))][j - 1].first ? ST[i][j - 1] : ST[i + (1 << (j - 1))][j - 1];
}
int LCA (int x, int y) {
int L = Dfn[x], R = Dfn[y];
if (L > R)
swap (L, R);
int k = log2 (R - L + 1);
return ST[L][k].first < ST[R - (1 << k) + 1][k].first ? ST[L][k].second : ST[R - (1 << k) + 1][k].second;
}
LL dist (int x, int y) {
int lca = LCA (x, y);
return Deep[x] + Deep[y] - (Deep[lca] << 1);
}

int father[MAXN]= {0};
bool Vis[MAXN]= {false};
int Size[MAXN]= {0};
int minv = INF, grvy;
int total;
void Grvy_Acqu (int root, int fa) {
Size[root] = 1;
int maxpart = 0;
for (int i = Head[root]; i; i = Link[i].next) {
int v = Link[i].to;
if (v == fa || Vis[v])
continue;
Grvy_Acqu (v, root);
Size[root] += Size[v];
maxpart = max (maxpart, Size[v]);
}
maxpart = max (maxpart, total - Size[root]);
if (maxpart < minv)
minv = maxpart, grvy = root;
}
LL sumo[MAXN]= {0}, sumt[MAXN]= {0}, sumfa[MAXN]= {0};
// sumo -> p子节点权值之和, sumt -> 子节点权值与距离的乘积到p之和, sumfa -> 子节点权值与距离的乘积到fa(点分树上)之和
void sums_Acqu (int root, int fa) {
sumo[grvy] += value[root], sumt[grvy] += value[root] * dist (root, grvy);
if (father[grvy])
sumfa[grvy] += value[root] * dist (root, father[grvy]);
for (int i = Head[root]; i; i = Link[i].next) {
int v = Link[i].to;
if (v == fa || Vis[v])
continue;
sums_Acqu (v, root);
}
}
void point_DAC (int p, int pre) {
minv = INF, grvy = p, total = Size[p];
Grvy_Acqu (p, 0);
Vis[grvy] = true, father[grvy] = pre;
sums_Acqu (grvy, 0);
int fgrvy = grvy;
for (int i = Head[fgrvy]; i; i = Link[i].next) {
int v = Link[i].to;
if (Vis[v])
continue;
point_DAC (v, fgrvy);
}
}

LL Query (int op) {
LL tsum = 0;
for (int p = op; p; p = father[p]) {
tsum += sumt[p];
if (p != op)
tsum += sumo[p] * dist (p, op);
if (father[p])
tsum -= sumo[p] * dist (father[p], op) + sumfa[p];
}
return tsum;
}
void Modify (int op, int delta) {
for (int p = op; p; p = father[p]) {
sumo[p] -= value[op] - delta;
sumt[p] -= value[op] * dist (p, op) - delta * dist (p, op);
if (father[p])
sumfa[p] -= value[op] * dist (father[p], op) - delta * dist (father[p], op);
}
sum -= value[op] - delta;
LL s = Query (op);
s1 += (delta - value[op]) * s;
value[op] = delta;
}

int getnum () {
int num = 0;
char ch = getchar ();
int isneg = 0;

while (! isdigit (ch)) {
if (ch == '-')
isneg = 1;
ch = getchar ();
}
while (isdigit (ch))
num = (num << 3) + (num << 1) + ch - '0', ch = getchar ();

return isneg ? - num : num;
}

int main () {
N = getnum (), Q = getnum ();
for (int i = 1; i < N; i ++) {
int u = getnum (), v = getnum ();
Insert (u, v), Insert (v, u);
}
for (int i = 1; i <= N; i ++)
value[i] = getnum ();
DFS (1, 0), RMQ ();
for (int i = 1; i <= N; i ++)
s1 += subtree[i] * (sum - subtree[i]);
Size[1] = N, point_DAC (1, 0);
/*cout << "Next----------------------" << endl;
for (int i = 1; i <= N; i ++)
cout << sumo[i] << ' ' << sumt[i] << ' ' << sumfa[i] << endl;
cout << "End-----------------------" << endl;*/
for (int Case = 1; Case <= Q; Case ++) {
int opt = getnum ();
if (opt == 1) {
int p = getnum (), delta = getnum ();
Modify (p, delta);
}
else if (opt == 2) {
int p = getnum ();
LL ans = (Query (p) + sum) * sum - s1;
printf ("%lld\n", ans);
}
}

return 0;
}

/*
4 5
1 2
2 3
2 4
4 3 2 1
2 2
1 1 3
2 3
1 2 4
2 4
*/

/*
4 1
1 2
2 3
2 4
4 3 2 1
2 1
*/