树形dp基础题
2.CF461B Appleman and Tree
题目描述
给你一棵有 $n$ 个节点的树,节点下标从 $0$ 开始。
第 $i$ 个节点可以为白色或黑色。
现在你可以从中删去若干条边,使得剩下的每个部分恰有一个黑色节点。
问有多少种符合条件的删边方法,答案对 $10^9+7$ 取模。
输入格式
第一行一个整数 $n\ (1\le n\le 10^5)$,表示节点个数。
接下来一行 $n-1$ 个整数 $p_0,p_1,\cdots,p_{n-2}\ (0\le p_i\le i)$,表示树中有一条连接节点 $p_i$ 和节点 $i+1$ 的边。
接下来一行 $n$ 个整数 $x_0,x_1,\cdots,x_{n-1}\ (0\le x_i\le 1)$,若 $x_i$ 为 $1$,则节点 $i$ 为黑色,否则为白色。
输出格式
第一行一个整数,表示符合条件的删边方案数对 $10^9+7$ 取模后的值。
题目分析
考虑树形 dp。对每个节点 $u$,它所在的连通块只有两种情况:1. 连通块中还没有黑点;2. 连通块中恰有一个黑点。设 $dp[u][0]$、$dp[u][1]$ 分别表示:在 $u$ 的子树中,其余已被切断的连通块都恰有一个黑点,而 $u$ 所在连通块没有黑点 / 恰有一个黑点的方案数。下文“合法”指连通块中恰有一个黑点,“不合法”指连通块中还没有黑点。合并子节点 $v$ 时,状态转移方程为:

$$dp[u][1]=dp[u][1]\cdot dp[v][0]+dp[u][1]\cdot dp[v][1]+dp[u][0]\cdot dp[v][1]$$

三项分别代表:1. $u$ 合法,$v$ 不合法,不切断;2. $u$ 合法,$v$ 合法,切断;3. $u$ 不合法,$v$ 合法,不切断。同理:

$$dp[u][0]=dp[u][0]\cdot dp[v][1]+dp[u][0]\cdot dp[v][0]$$

两项分别代表:1. $u$ 不合法,$v$ 合法,切断;2. $u$ 不合法,$v$ 也不合法,不切断。注意更新 $dp[u][1]$ 要用到旧的 $dp[u][0]$,所以应先更新 $dp[u][1]$,再更新 $dp[u][0]$。初始值为:若 $u$ 是黑点则 $dp[u][1]=1$,否则 $dp[u][0]=1$。最终答案为 $dp[0][1]$。
代码
// CF461B "Appleman and Tree": count the sets of deleted edges that leave
// exactly one black vertex in every remaining component, modulo 1e9+7.
#include <algorithm>
#include <array>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

using ll = long long;

/// Modular inverse of `a` modulo `b` via the extended Euclidean algorithm.
/// Requires gcd(a, b) == 1 (checked with assert). The result may be negative;
/// callers are expected to normalize it (Modular's constructor does).
template <typename T>
T inverse(T a, T b) {
    T u = 0, v = 1;
    while (a != 0) {
        T t = b / a;
        b -= t * a;
        std::swap(a, b);
        u -= t * v;
        std::swap(u, v);
    }
    assert(b == 1);
    return u;
}

/// Computes a^b by binary exponentiation. `b` must be non-negative
/// (a negative exponent would never terminate the shift loop).
template <typename T>
T power(T a, int b) {
    assert(b >= 0);
    T ans = 1;
    for (; b > 0; a *= a, b >>= 1) {
        if (b & 1) ans *= a;
    }
    return ans;
}

/// Integer residue modulo the compile-time constant `Mod`, stored in [0, Mod).
template <int Mod>
class Modular {
    // `value += other.value` must not overflow int: 2 * (Mod - 1) <= INT_MAX.
    static_assert(Mod > 0 && Mod <= (1 << 30), "Mod must be in (0, 2^30]");

public:
    using Type = int;

    constexpr Modular() : value() {}

    // Intentionally implicit so that `Z x = 1;`, `x * 2` and `x == 0` work.
    template <typename U>
    Modular(const U& x) : value(norm(x)) {}

    /// Reduces an arbitrary integer into [0, Mod).
    template <typename U>
    static Type norm(const U& x) {
        Type v;
        if (-Mod <= x && x < Mod) {
            v = static_cast<Type>(x);
        } else {
            v = static_cast<Type>(x % Mod);
        }
        if (v < 0) v += Mod;
        return v;
    }

    int val() const { return value; }
    const Type& operator()() const { return value; }
    template <typename U>
    explicit operator U() const { return static_cast<U>(value); }

    /// Multiplicative inverse; requires gcd(value, Mod) == 1.
    Modular inv() const { return Modular(inverse(value, Mod)); }

    Modular& operator+=(const Modular& other) {
        if ((value += other.value) >= Mod) value -= Mod;
        return *this;
    }
    Modular& operator-=(const Modular& other) {
        if ((value -= other.value) < 0) value += Mod;
        return *this;
    }
    Modular& operator*=(const Modular& other) {
        // Widen before multiplying: the product of two residues overflows int.
        value = static_cast<Type>(static_cast<ll>(value) * other.value % Mod);
        return *this;
    }
    Modular& operator/=(const Modular& other) { return *this *= other.inv(); }

    Modular& operator++() { return *this += 1; }
    Modular& operator--() { return *this -= 1; }
    Modular operator++(int) {
        Modular old(*this);
        ++*this;
        return old;
    }
    Modular operator--(int) {
        Modular old(*this);
        --*this;
        return old;
    }
    Modular operator-() const { return Modular(-value); }

    // Non-member friends: the implicit converting constructor lets either
    // operand be a plain integer (e.g. `1 + z`, `z == 0`).
    friend Modular operator+(Modular lhs, const Modular& rhs) { return lhs += rhs; }
    friend Modular operator-(Modular lhs, const Modular& rhs) { return lhs -= rhs; }
    friend Modular operator*(Modular lhs, const Modular& rhs) { return lhs *= rhs; }
    friend Modular operator/(Modular lhs, const Modular& rhs) { return lhs /= rhs; }

    // The previous overload set recursed into itself forever; compare the
    // stored residues directly instead.
    friend bool operator==(const Modular& lhs, const Modular& rhs) { return lhs.value == rhs.value; }
    friend bool operator!=(const Modular& lhs, const Modular& rhs) { return !(lhs == rhs); }
    friend bool operator<(const Modular& lhs, const Modular& rhs) { return lhs.value < rhs.value; }

    friend const Type& abs(const Modular& x) { return x.value; }

    template <class IStream>
    friend IStream& operator>>(IStream& is, Modular& rhs) {
        ll v;
        is >> v;
        rhs = Modular(v);
        return is;
    }
    template <class OStream>
    friend OStream& operator<<(OStream& os, const Modular& rhs) {
        return os << rhs.val();
    }

private:
    Type value;
};

constexpr int mod = 1000000007;
using Z = Modular<mod>;

/// Counts the ways to delete edges so that every remaining component contains
/// exactly one black vertex.
///
/// @param parent parent[i] is the parent of vertex i for i >= 1 (parent[0] is
///               ignored). The input format guarantees parent[i] < i.
/// @param color  color[i] != 0 marks vertex i as black.
/// @return the count modulo `mod`; 0 for an empty tree.
Z count_partitions(const std::vector<int>& parent, const std::vector<int>& color) {
    const int n = static_cast<int>(color.size());
    assert(parent.size() == color.size());
    if (n == 0) return 0;

    // dp[u][0]: u's open component has no black vertex yet;
    // dp[u][1]: u's open component has exactly one black vertex;
    // in both cases every already-cut component has exactly one black vertex.
    std::vector<std::array<Z, 2>> dp(n);
    for (int u = 0; u < n; ++u) dp[u][color[u] != 0 ? 1 : 0] = 1;

    // Since parent[v] < v, visiting v in decreasing order merges every child
    // of v into v before v itself is merged into its parent: a post-order
    // without recursion (a path-shaped tree would otherwise recurse 1e5 deep).
    for (int v = n - 1; v >= 1; --v) {
        assert(0 <= parent[v] && parent[v] < v);
        std::array<Z, 2>& up = dp[parent[v]];
        const std::array<Z, 2>& down = dp[v];
        // Edge (parent, v) may be kept with v's component white or black, or
        // cut when v's component is already complete (black).
        const Z either = down[0] + down[1];
        // up[1] must be computed from the old up[0], so update it first.
        up[1] = up[1] * either + up[0] * down[1];
        up[0] = up[0] * either;
    }
    return dp[0][1];
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;

    // Edge i connects p_{i-1} and i, i.e. vertex i's parent is p_{i-1}.
    std::vector<int> parent(n, -1);
    for (int i = 1; i < n; ++i) std::cin >> parent[i];

    std::vector<int> color(n);
    for (int& c : color) std::cin >> c;

    std::cout << count_partitions(parent, color) << '\n';
    return 0;
}