publicstaticvoidmain(final String[] args){ finalint m = 3; final List<Integer> nums = Arrays.asList(1, 2, 3, 4, 5); final Solution so = new Solution(); final List<List<Integer>> res = so.permute(nums, m); System.out.println(res); System.out.println(res.size()); }
}
classSolution{
List<Integer> nums = new ArrayList<>(); static List<List<Integer>> res = new ArrayList<>(); List<Boolean> visited = new ArrayList<>(); int m = 0;
/** * 带重复数字的全排列生成 * */ public List<List<Integer>> permute(final List<Integer> nums, finalint m) {
finalint n = nums.size(); if(nums.isEmpty()) return res;
defpermuteUnique(nums, m):# 1 """permutation with repeatable input :param nums: List[int] :return: List[List[int]] """ ifnot nums: return# 2 res = [] # 3 n = len(nums) # 4
defbacktrack2(nums, path, depth):# 5 """backtrack based on the digging out select number :param nums: List[int] :param path: List[int] :param depth: depth :return: None """ if depth == m and path notin res: res.append(path[:]) # 6 for i inrange(len(nums)): # 7 backtrack2(nums[:i]+nums[i+1:], path+[nums[i]], depth+1) # 8
给定一个数组 nums (无重复、不一定连续) ,求 k 个子元素组成的所有组合,14行代码实现。 written by hand require
1 2 3 4 5 6 7 8 9 10 11 12 13 14
defsubset(nums, k): defbacktrack(start = 0, path = []): cur_len = len(path) if cur_len == k: res.append(path[:]) return for i inrange(start, n - (k - cur_len) + 1): # 剪枝策略,详见LeetCode 77题 backtrack(i + 1, path + [nums[i]]) n = len(nums) if k == n: return [nums] if n <= 0or k <= 0or k > n: return [[]] res = [] backtrack() return res
defsubsets(nums): """calculate all kinds of subsets based on input list :param nums: List[int] :param k: int :return: List[List[int]] """ ifnot nums: return [] n = len(nums) res = [] nums.sort() defbacktrack(idx, n, path): res.append(path[:]) # 无需终止条件,一直往res种加回溯结果就行。 for i inrange(idx, n): backtrack(i + 1, n, path + [nums[i]]) backtrack(0, n, []) return re
遍历循环5行代码实现
1 2 3 4 5 6 7 8 9 10
defsubsets(nums): """calculate all kinds of subsets based on input list :param nums: List[int] :param k: int :return: List[List[int]] """ res = [[]] for num in nums: res += [curr + [num] for curr in res] return res
defsubsetWithDup(nums, k): """calculate subset with k elements based on input list :param nums: List[int] :param k: int :return: List[List[int]] """ defbacktrack(start = 0, path = []): cur_len = len(path) if cur_len == k: if path notin res: # 此处去重有效 res.append(path[:]) return for i inrange(start, n - (k - cur_len) + 1): # 剪枝策略,详见LeetCode 77题 # if i > start and nums[i] == nums[i-1]: # 此处去重无效 # continue backtrack(i + 1, path + [nums[i]]) n = len(nums) if k == n: return [nums] if n <= 0or k <= 0or k > n: return [[]] res = [] backtrack() return res
defsubsetsWithDup(nums): """calculate all kinds of subsets based on input list :param nums: List[int] :param k: int :return: List[List[int]] """ ifnot nums: return [] res = [] nums.sort() # 去重都要做排序
deftraversal(): """calculate all kinds of subsets based on traversal """ # 循环开始前先对nums进行排序 output = [[]] tmp = [] for i inrange(len(nums)): if i > 0and nums[i - 1] == nums[i]: tmp = [curr + [nums[i]] for curr in tmp] else: tmp = [curr + [nums[i]] for curr in output] output += tmp return output
defbacktrack(start, n, path): """calculate all kinds of subsets based on traversal :param start: int :param n: int :param path: List[int] :return: None """ # if path not in res: # 去重处1(35.29%) # res.append(path[:]) res.append(path[:]) # 无需终止条件,一直往res种加回溯结果就行。 for i inrange(start, n): if i > start and nums[i] == nums[i - 1]: # 去重处2(76.81%) continue backtrack(i + 1, n, path + [nums[i]])
# 基于回溯的方法 n = len(nums) backtrack(0, n, []) # 35.29%/76.81%
# 声明三个hash map用来保存存储关系 实现约束编程 rows = [defaultdict(int) for i inrange(N)] columns = [defaultdict(int) for i inrange(N)] boxes = [defaultdict(int) for i inrange(N)]
for index inrange(begin, length): # 剪枝判断,如果当前位置的数字大于residue,由于是sort了的,那后面也不用看了,接break掉。这里是剪枝。 if (sum(path) + candidates[index]) > target: break # 在这里进行一个判断,过滤掉那些紧随其后一样的数字,这么写就是为了防止不同数字重复解的出现。这里是去重复。 if index > begin and candidates[index-1] == candidates[index]: continue # index+1,这么写就是为了防止数字本身重复。 backtrack(candidates + [index], index+1, path, res, length)
直接回溯:不设置 is_validate 函数和 used 变量,每次都从下一个位置开始遍历,这样就跳过了重复的情况。
加上剪枝:剪枝要去掉的多余步骤一般出现在循环中,如果已经选了的元素都放到 select 中,一共要选择 k 个元素。那么循环干的事情,是从 [i, n] 这个区间里(注意,左右都是闭区间),找到 k - len(selects)个元素。 i 有一个上限。那这个上限 max(i) 是 n - (k -len(selects)) + 1。所以,我们的剪枝过程就是:把for循环的终止条件从 i < n+1 改成 i < n - (k - len(selects)) + 2。
1 2 3 4
for i inrange(first, n - (k - len(selects)) + 2): pre.append(i) backtrack(i + 1, selects) pre.pop()
也可以用字典序 (二进制排序) 组合来解决,不过理解起来不是很简单。
关键代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
classSolution: defcombine(self, n: int, k: int) -> List[List[int]]: defbacktrack(first=1, selects=[]): iflen(selects) == k: res.append(selects[:]) return # for i in range(first, n + 1): # 剪枝之前 66.70% for i inrange(first, n - (k - len(selects)) + 2): # 剪枝之后 97.27% selects.append(i) backtrack(i+1, selects) selects.pop() # 特判 if n <= 0or k <= 0or k > n: return [] if k == n: return [list(range(1, n+1))] res = [] backtrack() return res
# 思路1 defbacktrack1(idx, n, temp_list): if temp_list notin res: # 此处去重 res.append(temp_list) for i inrange(idx, n): backtrack1(i + 1, n, temp_list + [nums[i]]) # 思路2 defbacktrack2(idx, n, temp_list): res.append(temp_list) for i inrange(idx, n): if i > idx and nums[i] == nums[i - 1]: # 此处去重 continue backtrack2(i + 1, n, temp_list + [nums[i]])
遍历递归
1 2 3 4 5 6 7 8 9 10
# 循环开始前先对nums进行排序 nums.sort() tmp = [] for i inrange(len(nums)): if i > 0and nums[i - 1] == nums[i]: tmp = [curr + [nums[i]] for curr in tmp] else: tmp = [curr + [nums[i]] for curr in output] output += tmp return output
classSolution: defexist(self, borad, word): defbacktrack(borad, word, depth, row, col): if depth == len(word): # 终止 returnTrue if row < 0or col < 0or row >= rows or col >= cols or borad[row][col] != word[depth]: returnFalse char = borad[row][col] borad[row][col] = "#"# 选择 res = backtrack(borad, word, depth+1, row-1, col) or \ backtrack(borad, word, depth+1, row+1, col) or \ backtrack(borad, word, depth+1, row, col-1) or \ backtrack(borad, word, depth+1, row, col+1) # 回溯 borad[row][col] = char # 终止 return res
ifnotlen(borad) ornotlen(borad[0]): returnFalse rows = len(borad) cols = len(borad[0]) for row inrange(rows): for col inrange(cols): if backtrack(borad, word, 0, row, col): returnTrue returnFalse
classSolution: defgenerateTrees(self, n: int) -> List[TreeNode]: if n == 0: return [] return self.dfs(1, n)
defdfs(self, start, end): if start > end: return [None] res = [] for rootval inrange(start, end+1): LeftTree = self.dfs(start, rootval-1) RightTree = self.dfs(rootval+1, end) for i in LeftTree: for j in RightTree: root = TreeNode(rootval) root.left = i root.right = j res.append(root) return res