def sieve_of_eratosthenes(n):
    """Return every prime p with 2 <= p <= n, in ascending order.

    Sieve of Eratosthenes: each prime i crosses out its multiples starting
    at i*i (smaller multiples were already crossed out by a smaller prime
    factor), so only candidates up to sqrt(n) need to be processed.

    Args:
        n: inclusive upper bound.

    Returns:
        A list of primes; empty when n < 2.
    """
    if n < 2:
        # Guard: the table below needs indices 0 and 1 to exist.
        return []
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for i in range(2, int(n ** 0.5) + 1):
        if is_prime[i]:
            for j in range(i * i, n + 1, i):
                is_prime[j] = False
    return [i for i, flag in enumerate(is_prime) if flag]
2.2 最大公约数(GCD)
欧几里得算法(辗转相除法)
def gcd(a, b):
    """Return the greatest common divisor of integers a and b.

    Uses Euclid's algorithm (repeated division with remainder). The result
    is always non-negative, matching the mathematical convention and
    `math.gcd`; gcd(0, 0) is 0.
    """
    while b:
        a, b = b, a % b
    # Python's % takes the sign of the divisor, so the loop can finish with a
    # negative value (e.g. gcd(4, -6) would otherwise give -2).
    return abs(a)
2.3 组合数学
杨辉三角
def generate_pascal_triangle(n):
    """Return the first n rows of Pascal's triangle as a list of lists.

    Row i (0-based) holds the binomial coefficients C(i, 0) .. C(i, i).

    Args:
        n: number of rows to generate.

    Returns:
        A list of n rows; empty when n <= 0.
    """
    if n <= 0:
        return []
    triangle = [[1]]
    for _ in range(n - 1):
        prev = triangle[-1]
        # Each interior entry is the sum of the two entries directly above it.
        triangle.append([1] + [left + right for left, right in zip(prev, prev[1:])] + [1])
    return triangle
2.4 数论
快速幂算法
def pow_mod(base, exponent, mod):
    """Return ``base ** exponent % mod`` by binary (square-and-multiply) exponentiation.

    Uses O(log exponent) modular multiplications. Because every product is
    reduced modulo `mod`, the intermediate values stay small.

    Args:
        base: integer base; may be negative or larger than `mod`.
        exponent: non-negative integer exponent.
        mod: integer modulus (non-zero).

    Returns:
        The same value as the built-in ``pow(base, exponent, mod)``.

    Raises:
        ValueError: if `exponent` is negative (the loop would otherwise
            silently return 1).
    """
    if exponent < 0:
        raise ValueError("exponent must be non-negative")
    # 1 % mod, not 1: when mod == 1 every residue is 0, even for exponent 0.
    result = 1 % mod
    base %= mod
    while exponent > 0:
        if exponent & 1:
            result = result * base % mod
        exponent >>= 1
        base = base * base % mod
    return result
3. LeetCode 实战应用
3.1 LeetCode 204. Count Primes
计算小于非负整数 n 的质数的数量。
class Solution:
    def countPrimes(self, n: int) -> int:
        """Count the primes strictly less than n (LeetCode 204).

        Builds a sieve over [0, n) and crosses out multiples of each prime
        up to sqrt(n) with a single slice assignment per prime.
        """
        if n <= 2:
            return 0
        is_prime = [True] * n
        is_prime[0] = is_prime[1] = False
        limit = int(n ** 0.5)
        for p in range(2, limit + 1):
            if not is_prime[p]:
                continue
            start = p * p
            # Multiples below p*p were already removed by a smaller factor.
            is_prime[start::p] = [False] * len(range(start, n, p))
        return sum(is_prime)
3.2 LeetCode 69. Sqrt(x)
实现 int sqrt(int x) 函数：计算并返回非负整数 x 的平方根，结果只保留整数部分（小数部分舍去）。
class Solution:
    def mySqrt(self, x: int) -> int:
        """Return floor(sqrt(x)) by binary search (LeetCode 69).

        Inputs below 2 are returned unchanged. Otherwise the answer lies in
        [1, x // 2]. The search narrows [lo, hi] until it crosses, and `hi`
        then holds the largest value whose square does not exceed x.
        """
        if x < 2:
            return x
        lo, hi = 2, x // 2
        while lo <= hi:
            mid = (lo + hi) // 2
            sq = mid * mid
            if sq < x:
                lo = mid + 1
            elif sq > x:
                hi = mid - 1
            else:
                return mid
        return hi
3.3 LeetCode 50. Pow(x, n)
实现 pow(x, n)，即计算 x 的 n 次幂（xⁿ）的函数。
class Solution:
    def myPow(self, x: float, n: int) -> float:
        """Compute x raised to the integer power n (LeetCode 50).

        Binary exponentiation: peel off the low bit of n each round,
        multiplying the accumulator by the current square when the bit is
        set. A negative n is handled by inverting x first (so x == 0 with
        n < 0 raises ZeroDivisionError).
        """
        if n < 0:
            x, n = 1 / x, -n
        acc = 1
        square = x
        while n > 0:
            n, bit = divmod(n, 2)
            if bit == 1:
                acc *= square
            square *= square
        return acc
def chinese_remainder_theorem(nums, remainders):
    """Solve the system x ≡ remainders[i] (mod nums[i]) for all i.

    Standard CRT construction: for each modulus, take p = (product of all
    moduli) // modulus, and add remainder * p * inverse(p mod modulus).

    Args:
        nums: the moduli. They are expected to be pairwise coprime.
        remainders: the target residues, one per modulus.

    Returns:
        The unique solution in [0, product of nums). Returns 0 for an empty
        system.

    Raises:
        ValueError: if `nums` and `remainders` differ in length.
        Whatever `mod_inverse` raises when p has no inverse modulo a modulus,
        which happens when the moduli are not pairwise coprime.
    """
    from functools import reduce
    from operator import mul

    # Materialize once so generators are not exhausted by the product step.
    nums = list(nums)
    remainders = list(remainders)
    if len(nums) != len(remainders):
        raise ValueError("nums and remainders must have the same length")

    # Initializer 1 makes an empty system well-defined instead of a TypeError.
    product = reduce(mul, nums, 1)
    total = 0
    for num, remainder in zip(nums, remainders):
        p = product // num
        total += remainder * mod_inverse(p, num) * p
    return total % product
def mod_inverse(a, m):
    """Return the modular multiplicative inverse of a modulo m.

    Finds x in [0, m) with (a * x) % m == 1 using the extended Euclidean
    algorithm.

    Args:
        a: integer to invert; may be negative or larger than m.
        m: positive modulus.

    Returns:
        The inverse of a modulo m.

    Raises:
        ValueError: if m is not positive, or if gcd(a, m) != 1 (no inverse
            exists). ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    def extended_gcd(a, b):
        # Returns (g, x, y) such that a*x + b*y == g.
        if a == 0:
            return b, 0, 1
        g, x, y = extended_gcd(b % a, a)
        # b % a == b - (b // a) * a, so back-substitute the coefficients.
        return g, y - (b // a) * x, x

    if m <= 0:
        raise ValueError("modulus must be positive")
    # Reduce first: with a negative `a` the recursion yields a negative gcd
    # (a=-3, m=7 gives -1), which would wrongly report that no inverse exists.
    g, x, _ = extended_gcd(a % m, m)
    if g != 1:
        # Message means "modular inverse does not exist".
        raise ValueError('模逆不存在')
    return x % m
def fft(a):
    """Return the discrete Fourier transform of `a` (recursive radix-2 Cooley-Tukey).

    Computes X[k] = sum_j a[j] * exp(-2*pi*i*j*k/n). This is the standard
    forward DFT sign convention, also used by numpy.fft.fft.

    Args:
        a: a sequence of numbers whose length is a power of two (or 0/1).

    Returns:
        A new list of complex values of the same length.

    Raises:
        ValueError: if len(a) > 1 is not a power of two. The even/odd split
            below only lines up when both halves have equal length.
    """
    import cmath

    n = len(a)
    if n <= 1:
        return list(a)
    if n & (n - 1):
        raise ValueError(f"fft length must be a power of two, got {n}")
    half = n // 2
    even = fft(a[0::2])
    odd = fft(a[1::2])
    out = [0j] * n
    for k in range(half):
        # Twiddle factor computed once per k and shared by both butterfly outputs.
        t = cmath.exp(-2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + half] = even[k] - t
    return out
def ifft(a):
    """Return the inverse of `fft` applied to `a`.

    Uses the conjugation identity: conjugate the input, run the forward
    transform, conjugate the output, and divide by the length. Because `fft`
    recursively splits its input into halves, len(a) should be a power of two.
    """
    length = len(a)
    transformed = fft(list(map(lambda z: z.conjugate(), a)))
    return [z.conjugate() / length for z in transformed]
def multiply(a, b):
    """Multiply two integer-coefficient polynomials using the FFT.

    Coefficients are listed from the constant term upward, so [1, 2] means
    1 + 2x.

    Args:
        a: coefficients of the first polynomial. Not modified.
        b: coefficients of the second polynomial. Not modified.

    Returns:
        Coefficients of a * b, each rounded to the nearest int, with
        high-order zero coefficients stripped. The zero polynomial is
        returned as [0]; if either input is empty, [] is returned.

    Note:
        The transform runs in floating point. Very large coefficients or very
        long inputs may exceed what rounding can recover exactly.
    """
    if not a or not b:
        return []
    # Smallest power of two that can hold the len(a) + len(b) - 1 product terms.
    size = 1
    while size < len(a) + len(b):
        size <<= 1
    # Pad copies, not the caller's lists (the original extended them in place).
    fa = fft(list(a) + [0] * (size - len(a)))
    fb = fft(list(b) + [0] * (size - len(b)))
    # Pointwise product in the frequency domain is convolution in coefficient space.
    fc = [x * y for x, y in zip(fa, fb)]
    c = [round(x.real) for x in ifft(fc)]
    # Keep at least one coefficient so an all-zero product does not empty the list.
    while len(c) > 1 and c[-1] == 0:
        c.pop()
    return c