diff --git a/TangDou/Topic/HuanGenDp/POJ3585.cpp b/TangDou/Topic/HuanGenDp/POJ3585.cpp index 9b485b3..58de542 100644 --- a/TangDou/Topic/HuanGenDp/POJ3585.cpp +++ b/TangDou/Topic/HuanGenDp/POJ3585.cpp @@ -3,55 +3,55 @@ #include #include using namespace std; -const int N = 2e5 + 10; +const int N = 2e5 + 10, M = N << 1; + int T, n, du[N]; -struct node { - int to, w; -}; -vector c[N]; +// 链式前向星 +int e[M], h[N], idx, w[M], ne[M]; +void add(int a, int b, int c = 0) { + e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++; +} -bool v[N]; int d[N], f[N]; void clean() { - memset(v, 0, sizeof v); + // 初始化链式前向星 + memset(h, -1, sizeof h); + idx = 0; + memset(du, 0, sizeof du); memset(d, 0, sizeof d); memset(f, 0, sizeof f); - - for (int i = 1; i <= n; i++) c[i].clear(); } -void dp(int x) { - v[x] = 1; - int size = c[x].size() - 1; - for (int i = 0; i <= size; i++) { - int to = c[x][i].to, w = c[x][i].w; - if (v[to]) continue; - dp(to); - if (du[to] == 1) - d[x] += w; +void dp(int u, int fa) { + for (int i = h[u]; ~i; i = ne[i]) { + int v = e[i]; + if (v == fa) continue; + dp(v, u); + if (du[v] == 1) + d[u] += w[i]; else - d[x] += min(d[to], w); + d[u] += min(d[v], w[i]); } return; } -void dfs(int x) { - v[x] = 1; - int size = c[x].size() - 1; - for (int i = 0; i <= size; i++) { - int to = c[x][i].to, w = c[x][i].w; - if (v[to]) continue; - if (du[x] == 1) - f[to] = d[to] + w; +void dfs(int u, int fa) { + for (int i = h[u]; ~i; i = ne[i]) { + int v = e[i]; + if (v == fa) continue; + if (du[u] == 1) + f[v] = d[v] + w[i]; else - f[to] = d[to] + min(f[x] - min(d[to], w), w); - dfs(to); + f[v] = d[v] + min(f[u] - min(d[v], w[i]), w[i]); + dfs(v, u); } return; } int main() { + // 加快读入 + ios::sync_with_stdio(false), cin.tie(0); cin >> T; while (T--) { cin >> n; @@ -59,15 +59,13 @@ int main() { int x, y, z; for (int i = 1; i < n; i++) { cin >> x >> y >> z; - c[x].push_back((node){y, z}); - c[y].push_back((node){x, z}); - du[x]++; - du[y]++; + add(x, y, z), add(y, x, z); + du[x]++, du[y]++; } - dp(1); + dp(1, 0); f[1] = d[1]; - memset(v, 0, sizeof v); - dfs(1); + + dfs(1, 0); int ans = 0; for (int i = 1; i <= n; i++) ans = max(ans, f[i]); printf("%d\n", ans);