Kimi    LeetCode 3410. 删除所有值为某个元素后的最大子数组和 Python3实现

以下是 LeetCode 3410. 删除所有值为某个元素后的最大子数组和 的 Python3 实现。

题目理解

给定整数数组 `nums`,可以至多一次选择某个值 `X`,删除数组中所有值为 `X` 的元素(删除后数组必须非空),求能得到数组的最大子数组和。

- 示例1:`nums = [-3,2,-2,-1,3,-2,3]`,删除所有 `-2` 后得到 `[-3,2,-1,3,3]`,最大子数组和为 `2+(-1)+3+3 = 7`。
- 示例2:`nums = [1,2,3,4]`,不删除任何元素,最大子数组和为 `10`。

解题思路

这道题的核心是线段树维护最大子段和 :

1. 不删除时:就是经典的「最大子数组和」(LeetCode 53),可用动态规划或线段树解决。
2. 删除某个值 X 时:等价于将数组中所有值为 `X` 的位置设为 `0`,然后求最大子数组和。
3. 枚举所有可能的 X:用哈希表按值分组记录下标,对每个不同的值,将其所有出现位置在线段树中更新为 `0`,查询全局最大子段和,然后恢复。

线段树每个节点需要维护 4 个信息:
- `sm`:区间总和
- `lv`:从区间左边界开始的最大子段和
- `rv`:以区间右边界结束的最大子段和
- `ans`:区间内的最大子段和

合并两个区间时,最大子段和有三种来源:完全在左区间、完全在右区间、横跨两个区间。

注意:Python 中这道题卡常,需要手写 `max` 函数(用 `lambda` 或普通函数),否则用内置 `max` 会超时 。

Python3 代码

```python
class Node:
def __init__(self, sm, lv, rv, ans):
self.sm = sm # 区间和
self.lv = lv # 从左边界开始的最大子段和
self.rv = rv # 以右边界结束的最大子段和
self.ans = ans # 区间内的最大子段和

class Solution:
def maxSubarraySum(self, nums):
n = len(nums)

# 特殊情况:全是负数时,子段必须非空,只能选最大的负数
mx = -10**9
for x in nums:
mx = mx if mx > x else x
if mx <= 0:
return mx

# 手动比大小,效率更高,不这么写 Python 会超时
def mymax(a, b):
return b if b > a else a

# 线段树维护最大子段和
tree = [Node(0, 0, 0, 0) for _ in range(2 << n.bit_length())]

def merge(nl, nr):
"""合并两个区间"""
return Node(
nl.sm + nr.sm,
mymax(nl.lv, nl.sm + nr.lv), # 左区间左半部分,或左边全选+右边左部分
mymax(nl.rv + nr.sm, nr.rv), # 右区间右半部分,或左边右部分+右边全选
mymax(mymax(nl.ans, nr.ans), nl.rv + nr.lv) # 左区间、右区间、或横跨
)

def initNode(val):
return Node(val, val, val, val)

def build(id, l, r):
if l == r:
tree[id] = initNode(nums[l])
else:
nxt = id << 1
mid = (l + r) >> 1
build(nxt, l, mid)
build(nxt + 1, mid + 1, r)
tree[id] = merge(tree[nxt], tree[nxt + 1])

def modify(id, l, r, pos, val):
"""单点更新:将位置 pos 的值改为 val"""
if l == r:
tree[id] = initNode(val)
else:
nxt = id << 1
mid = (l + r) >> 1
if pos <= mid:
modify(nxt, l, mid, pos, val)
else:
modify(nxt + 1, mid + 1, r, pos, val)
tree[id] = merge(tree[nxt], tree[nxt + 1])

# 构建线段树
build(1, 0, n - 1)

# 不删除任何元素的答案
ans = tree[1].ans

# 按值分组,记录每个值出现的所有下标
from collections import defaultdict
mp = defaultdict(list)
for i in range(n):
mp[nums[i]].append(i)

# 枚举删除哪种数
for val, indices in mp.items():
if len(indices) != n: # 删除后需要保证数组不为空
# 把这种数都改成 0(模拟删除)
for x in indices:
modify(1, 0, n - 1, x, 0)
# 更新答案
ans = mymax(ans, tree[1].ans)
# 恢复
for x in indices:
modify(1, 0, n - 1, x, val)

return ans
```

复杂度分析

- 时间复杂度:O(n \log n),构建线段树 O(n \log n),枚举每个不同值并更新 O(k \cdot m \cdot \log n),其中 k 是不同值的个数,m 是该值出现次数,总和为 O(n \log n)。
- 空间复杂度:O(n),线段树需要约 4n 个节点。