Skip to content

Commit 6e82e13

Browse files
authored
Create maximize-sum-of-weights-after-edge-removals.py
1 parent 039cdb8 commit 6e82e13

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Time: O(n)
2+
# Space: O(n)
3+
4+
import random
5+
6+
7+
# iterative dfs, quick select
8+
class Solution(object):
9+
def maximizeSumOfWeights(self, edges, k):
10+
"""
11+
:type edges: List[List[int]]
12+
:type k: int
13+
:rtype: int
14+
"""
15+
def nth_element(nums, n, compare=lambda a, b: a < b):
16+
def tri_partition(nums, left, right, target):
17+
i = left
18+
while i <= right:
19+
if compare(nums[i], target):
20+
nums[i], nums[left] = nums[left], nums[i]
21+
left += 1
22+
i += 1
23+
elif compare(target, nums[i]):
24+
nums[i], nums[right] = nums[right], nums[i]
25+
right -= 1
26+
else:
27+
i += 1
28+
return left, right
29+
30+
left, right = 0, len(nums)-1
31+
while left <= right:
32+
pivot_idx = random.randint(left, right)
33+
pivot_left, pivot_right = tri_partition(nums, left, right, nums[pivot_idx])
34+
if pivot_left <= n <= pivot_right:
35+
return
36+
elif pivot_left > n:
37+
right = pivot_left-1
38+
else: # pivot_right < n.
39+
left = pivot_right+1
40+
41+
def iter_dfs():
42+
cnt = [[0]*2 for _ in xrange(len(adj))]
43+
stk = [(1, 0, -1)]
44+
while stk:
45+
step, u, p = stk.pop()
46+
if step == 1:
47+
stk.append((2, u, p))
48+
for v, w in reversed(adj[u]):
49+
if v == p:
50+
continue
51+
stk.append((1, v, u))
52+
elif step == 2:
53+
curr = 0
54+
diff = []
55+
for v, w in adj[u]:
56+
if v == p:
57+
continue
58+
curr += cnt[v][0]
59+
diff.append(max((cnt[v][1]+w)-cnt[v][0], 0))
60+
if k-1 < len(diff):
61+
nth_element(diff, k-1, lambda a, b: a > b)
62+
cnt[u][0] = curr+sum(diff[i] for i in xrange(min(k, len(diff))))
63+
cnt[u][1] = curr+sum(diff[i] for i in xrange(min(k-1, len(diff))))
64+
return cnt[0][0]
65+
66+
adj = [[] for _ in xrange(len(edges)+1)]
67+
for u, v, w in edges:
68+
adj[u].append((v, w))
69+
adj[v].append((u, w))
70+
return iter_dfs()
71+
72+
73+
# Time: O(n)
74+
# Space: O(n)
75+
import random
76+
77+
78+
# dfs, quick select
79+
class Solution2(object):
80+
def maximizeSumOfWeights(self, edges, k):
81+
"""
82+
:type edges: List[List[int]]
83+
:type k: int
84+
:rtype: int
85+
"""
86+
def nth_element(nums, n, compare=lambda a, b: a < b):
87+
def tri_partition(nums, left, right, target):
88+
i = left
89+
while i <= right:
90+
if compare(nums[i], target):
91+
nums[i], nums[left] = nums[left], nums[i]
92+
left += 1
93+
i += 1
94+
elif compare(target, nums[i]):
95+
nums[i], nums[right] = nums[right], nums[i]
96+
right -= 1
97+
else:
98+
i += 1
99+
return left, right
100+
101+
left, right = 0, len(nums)-1
102+
while left <= right:
103+
pivot_idx = random.randint(left, right)
104+
pivot_left, pivot_right = tri_partition(nums, left, right, nums[pivot_idx])
105+
if pivot_left <= n <= pivot_right:
106+
return
107+
elif pivot_left > n:
108+
right = pivot_left-1
109+
else: # pivot_right < n.
110+
left = pivot_right+1
111+
112+
def dfs(u, p):
113+
result = 0
114+
diff = []
115+
for v, w in adj[u]:
116+
if v == p:
117+
continue
118+
cnt = dfs(v, u)
119+
result += cnt[0]
120+
diff.append(max((cnt[1]+w)-cnt[0], 0))
121+
if k-1 < len(diff):
122+
nth_element(diff, k-1, lambda a, b: a > b)
123+
return (result+sum(diff[i] for i in xrange(min(k, len(diff)))), result+sum(diff[i] for i in xrange(min(k-1, len(diff)))))
124+
125+
adj = [[] for _ in xrange(len(edges)+1)]
126+
for u, v, w in edges:
127+
adj[u].append((v, w))
128+
adj[v].append((u, w))
129+
return dfs(0, -1)[0]

0 commit comments

Comments
 (0)