CF461B Appleman and Tree

树形dp基础题

2.CF461B Appleman and Tree

题目描述

给你一棵有 nn 个节点的树,下标从 00 开始。

ii 个节点可以为白色或黑色。

现在你可以从中删去若干条边,使得剩下的每个部分恰有一个黑色节点。

问有多少种符合条件的删边方法,答案对 109+710^9+7 取模。

输入格式

第一行一个整数 n(1n105)n(1\leq n\leq 10^5),表示节点个数。

接下来一行 n1n-1 个整数 (p0,p1,,pn2,0pii)(p_0,p_1,\cdots,p_{n-2},0\leq p_i\leq i),表示树中有一条连接节点 pip_i 和节点 i+1i+1 的边。

接下来一行 nn 个整数 (x0,x1,,xn1,0xi1)(x_0,x_1,\cdots,x_{n-1},0\leq x_i\leq 1),若 xix_i11,则节点 ii 为黑色,否则为白色。

输出格式

第一行一个整数,表示符合条件的删边方法的方案数对 109+710^9+7 取模后的值。

题目分析

考虑树形dp,发现每个节点和其子树的关系有:1.节点所在的连通块没有1。2.节点所在的连通块有1。设dp[u][0/1]dp[u][0/1]为u节点所在的连通块没有1的方案,有1的方案。那么可以得到dp的状态转移方程:dp[u][1]=dp[u][1]dp[v][0]+dp[u][1]dp[v][1]+dp[u][0]dp[v][1]dp[u][1]=dp[u][1]*dp[v][0]+dp[u][1]*dp[v][1]+dp[u][0]*dp[v][1]三个项分别代表1.u合法,v不合法,不切断。2.u合法,v合法,切断。3.u不合法,v合法,不切断。同理,dp[u][0]=dp[u][0]dp[v][1]+dp[u][0]dp[v][0]dp[u][0]=dp[u][0]*dp[v][1]+dp[u][0]*dp[v][0]代表1.u不合法,v合法,切断。2.u不合法,v也不合法,不切断。每个节点处理一下初始值即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
//
// Created by mrx on 2022/9/27.
//
#include <vector>
#include <algorithm>
#include <iostream>
#include <array>

using ll = long long;

template<typename T>
T inverse(T a, T b) {
T u = 0, v = 1;
while (a != 0) {
T t = b / a;
b -= t * a;
std::swap(a, b);
u -= t * v;
std::swap(u, v);
}
assert(b == 1);
return u;
}

template<typename T>
T power(T a, int b) {
T ans = 1;
for (; b; a *= a, b >>= 1) {
if (b & 1)ans *= a;
}
return ans;
}

template<int Mod>
class Modular {
public:
using Type = int;

template<typename U>
static Type norm(U& x) {
Type v;
if (-Mod <= x && x < Mod) v = static_cast<Type>(x);
else v = static_cast<Type>(x % Mod);
if (v < 0) v += Mod;
return v;
}

constexpr Modular() : value() {}

int val() const { return value; }

Modular inv() const {
return Modular(inverse(value, Mod));
}

template<typename U>
Modular(const U& x) {
value = norm(x);
}

const Type& operator ()() const {
return value;
}

template<typename U>
explicit operator U() const {
return static_cast<U>(value);
}

Modular& operator +=(const Modular& other) {
if ((value += other.value) >= Mod) value -= Mod;
return *this;
}

Modular& operator -=(
const Modular& other) {
if ((value -= other.value) < 0) value += Mod;
return *this;
}

template<typename U>
Modular& operator +=(const U& other) { return *this += Modular(other); }

template<typename U>
Modular& operator -=(const U& other) { return *this -= Modular(other); }

Modular& operator ++() { return *this += 1; }

Modular& operator --() { return *this -= 1; }

Modular operator ++(int) {
Modular result(*this);
*this += 1;
return result;
}

Modular operator --(int) {
Modular result(*this);
*this -= 1;
return result;
}

Modular operator -() const { return Modular(-value); }

template<class ISTREAM_TYPE>
friend ISTREAM_TYPE& operator >>(ISTREAM_TYPE& is, Modular& rhs) {
ll v;
is >> v;
rhs = Modular(v);
return is;
}

template<class OSTREAM_TYPE>
friend OSTREAM_TYPE& operator <<(OSTREAM_TYPE& os, const Modular& rhs) {
return os << rhs.val();
}

Modular& operator *=(const Modular& rhs) {
value = ll(value) * rhs.value % Mod;
return *this;
}

Modular& operator /=(const Modular& other) { return *this *= Modular(inverse(other.value, Mod)); }

friend const Type& abs(const Modular& x) { return x.value; }

friend bool operator ==(const Modular& lhs, const Modular& rhs) { return lhs.value == rhs.value; }

friend bool operator <(const Modular& lhs, const Modular& rhs) { return lhs.value < rhs.value; }


bool operator ==(const Modular& rhs) { return *this == rhs.value; }

template<typename U>
bool operator ==(U rhs) { return *this == Modular(rhs); }

template<typename U>
friend bool operator ==(U lhs, const Modular& rhs) { return Modular(lhs) == rhs; }

bool operator !=(const Modular& rhs) { return *this != rhs; }

template<typename U>
bool operator !=(U rhs) { return *this != rhs; }

template<typename U>
friend bool operator !=(U lhs, const Modular& rhs) { return lhs != rhs; }

bool operator <(const Modular& rhs) { return this->value < rhs.value; }

Modular operator +(const Modular& rhs) { return Modular(*this) += rhs; }

template<typename U>
Modular operator +(U rhs) { return Modular(*this) += rhs; }

template<typename U>
friend Modular operator +(U lhs, const Modular& rhs) { return Modular(lhs) += rhs; }

Modular operator -(const Modular& rhs) { return Modular(*this) -= rhs; }

template<typename U>
Modular operator -(U rhs) { return Modular(*this) -= rhs; }

template<typename U>
friend Modular operator -(U lhs, const Modular& rhs) { return Modular(lhs) -= rhs; }

Modular operator *(const Modular& rhs) { return Modular(*this) *= rhs; }

template<typename U>
Modular operator *(U rhs) { return Modular(*this) *= rhs; }

template<typename U>
friend Modular operator *(U lhs, const Modular& rhs) { return Modular(lhs) *= rhs; }

Modular operator /(const Modular& rhs) { return Modular(*this) /= rhs; }

template<typename U>
Modular operator /(U rhs) { return Modular(*this) /= rhs; }

template<typename U>
friend Modular operator /(U lhs, const Modular& rhs) { return Modular(lhs) /= rhs; }

private:
Type value;
};

const int mod = 1e9 + 7;
using Z = Modular<mod>;

int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);

int n;
std::cin >> n;
std::vector<std::vector<int>> G(n);
for (int i = 1; i < n; ++i) {
int x;
std::cin >> x;
G[i].push_back(x);
G[x].push_back(i);
}
std::vector<int> color(n);
for (int i = 0; i < n; ++i) {
std::cin >> color[i];
}

std::vector<std::array<Z, 2>> dp(n);
std::function<void(int, int)> dfs = [&](int u, int fa) {
dp[u][color[u]] = 1;
for (auto v: G[u]) {
if (v == fa)continue;
dfs(v, u);
dp[u][1] = dp[u][0] * dp[v][1] + dp[u][1] * (dp[v][1] + dp[v][0]);
dp[u][0] = dp[u][0] * (dp[v][1] + dp[v][0]);
}
};


dfs(0, -1);
// for (int i = 0; i < n; ++i) {
// std::cout << dp[i][0] << ' ' << dp[i][1] << '\n';
// }
std::cout << dp[0][1] << '\n';
return 0;
}

CF461B Appleman and Tree
https://mrxyan6.github.io/2022/09/27/CF461B/
作者
mrx
发布于
2022年9月27日
许可协议