4.3. itertools — 迭代器函数

目的:itertools 库包括了一系列处理序列型数据的函数。

在函数式编程语言如 Clojure,Haskell,APL 以及 SML 中,迭代是一个重要的概念。受到这些语言的启发,itertools 库提供了一些关于迭代运算的基础函数,通过对这些基本函数的有机组合,我们可以实现更加复杂的迭代算法。这种实现通常具有更好的时间,空间性能。

通常基于迭代的算法实现比基于列表的实现更加节省内存。因为知道迭代器需要一个数据,这个数据才会被生成并放入内存。这不同于列表结构,把所有的元素都生成了放在内存里。对于大数据集,这种惰性计算模型可以节省数据的交换以及其他一些副作用,从而提升程序的性能。

除了 itertools 库中定义的函数,这节的例子中也会用到一些关于迭代运算的内置函数。

合并与分割迭代器

chain() 函数将几个迭代器作为参数,并返回一个迭代器。这个迭代器将会依次遍历那些作为输入的迭代器。

itertools_chain.py

from itertools import *

for i in chain([1, 2, 3], ['a', 'b', 'c']):
    print(i, end=' ')
print()

运用 chain() 函数,可以很轻松的把多个序列合并成一个大的列表。

$ python3 itertools_chain.py

1 2 3 a b c

有时,甚至输入的那些迭代器也是动态生成的(你可能根本不知道有几个迭代器作为输入),这时我们可以用 chain.from_iterable() 来完成和 chain() 相类似的功能。

itertools_chain_from_iterable.py

from itertools import *

def make_iterables_to_chain():
    yield [1, 2, 3]
    yield ['a', 'b', 'c']

for i in chain.from_iterable(make_iterables_to_chain()):
    print(i, end=' ')
print()
$ python3 itertools_chain_from_iterable.py

1 2 3 a b c

内置函数 zip() 返回一个迭代器,这个迭代器将同时遍历多个输入迭代器,并返回一个由在这些迭代器中得到的元素所组合成的元祖。

itertools_zip.py

for i in zip([1, 2, 3], ['a', 'b', 'c']):
    print(i)

和这个库中的其他函数一样,返回一个只可以被遍历一次的可迭代对象。

$ python3 itertools_zip.py

(1, 'a')
(2, 'b')
(3, 'c')

zip() 将会在任意一个输入迭代器被遍历完时停止。如果想完整的遍历所有的输入迭代器(即使它们有不同的长度),我们可以用 zip_longest() 函数。

itertools_zip_longest.py

from itertools import *

r1 = range(3)
r2 = range(2)

print('zip stops early:')
print(list(zip(r1, r2)))

r1 = range(3)
r2 = range(2)

print('\nzip_longest processes all of the values:')
print(list(zip_longest(r1, r2)))

默认情况下,zip_longest() 函数将 None 作为缺省值。我们可以通过传入 fillvalue 参数来修改这个默认设定。

$ python3 itertools_zip_longest.py

zip stops early:
[(0, 0), (1, 1)]

zip_longest processes all of the values:
[(0, 0), (1, 1), (2, None)]

islice() 函数将把输入迭代器的一部分作为其输出的迭代器。

itertools_islice.py

from itertools import *

print('Stop at 5:')
for i in islice(range(100), 5):
    print(i, end=' ')
print('\n')

print('Start at 5, Stop at 10:')
for i in islice(range(100), 5, 10):
    print(i, end=' ')
print('\n')

print('By tens to 100:')
for i in islice(range(100), 0, 100, 10):
    print(i, end=' ')
print('\n')

islice() 和 slice 操作一样,将 startstop 以及 step 作为输入参数。其中 startstep 参数是可选的。

$ python3 itertools_islice.py

Stop at 5:
0 1 2 3 4

Start at 5, Stop at 10:
5 6 7 8 9

By tens to 100:
0 10 20 30 40 50 60 70 80 90

基于输入,tee() 函数返回多个(默认两个)独立的迭代器。

itertools_tee.py

from itertools import *

r = islice(count(), 5)
i1, i2 = tee(r)

print('i1:', list(i1))
print('i2:', list(i2))

在 Unix 系统中,tee 是一个非常基本的命令。这个命令将它的输入同时输出到一个给定的文件和屏幕(标准输出)。
类似的 tee() 函数将会基于它输入的迭代器返回多个和输入迭代器相同的迭代器,这些迭代器可以被数并列地输入不同的算法。

$ python3 itertools_tee.py

i1: [0, 1, 2, 3, 4]
i2: [0, 1, 2, 3, 4]

这些由 tee() 返回的迭代器会共享它们的输入。所以在用 tee() 创建了一些新的迭代器之后,通常原始的那个迭代器不应该再被使用。

itertools_tee_error.py

from itertools import *

r = islice(count(), 5)
i1, i2 = tee(r)

print('r:', end=' ')
for i in r:
    print(i, end=' ')
    if i > 1:
        break
print()

print('i1:', list(i1))
print('i2:', list(i2))

如果原始输入的迭代器中的一些数据被遍历了,这些数据将不会出现在那些新的(返回的)迭代器中。

$ python3 itertools_tee_error.py

r: 0 1 2
i1: [3, 4]
i2: [3, 4]

变换输入

内置函数 map() 会将一个函数分别作用到输入的一个迭代器的每个元素数,并将其以迭代器的形式返回。这个迭代器会遍历完这个输入的迭代器。

itertools_map.py


def times_two(x):
    return 2 * x

def multiply(x, y):
    return (x, y, x * y)

print('Doubles:')
for i in map(times_two, range(5)):
    print(i)

print('\nMultiples:')
r1 = range(5)
r2 = range(5, 10)
for i in map(multiply, r1, r2):
    print('{:d} * {:d} = {:d}'.format(*i))

print('\nStopping:')
r1 = range(5)
r2 = range(2)
for i in map(multiply, r1, r2):
    print(i)

在第一个例子中, lambda 函数将返回输入值的两倍。在第二个例子中,函数 multiply 返回两个输入参数的积。这个例子中返回的迭代器将这个函数作用到两个独立的迭代器上,其中的返回的元组由原始输入参数和它们的乘积组成。在第三个例子中,只有两个输出。因为第二个输入迭代器的比第一个短,只有两个元素。

$ python3 itertools_map.py

Doubles:
0
2
4
6
8

Multiples:
0 * 5 = 0
1 * 6 = 6
2 * 7 = 14
3 * 8 = 24
4 * 9 = 36

Stopping:
(0, 0, 0)
(1, 1, 1)

starmap() 这个函数和 map() 的作用很相似。但是 map() 函数的输入实际上是由多个(或单个)迭代器组成的元组,而 starmap() 遍历的是一个返回元组的单个迭代器。它会用 * 记号把元组分离成参数列表传入函数。

itertools_starmap.py

from itertools import *

values = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9)]

for i in starmap(lambda x, y: (x, y, x * y), values):
    print('{} * {} = {}'.format(*i))

也就是说,对 map() 函数来说,它可以有多个参数 (i1, i2),然后返回 f(i1, i2)。但对于 starmap() 来说,输入只有一个元组 i,它返回的是 f(*i)

$ python3 itertools_starmap.py

0 * 5 = 0
1 * 6 = 6
2 * 7 = 14
3 * 8 = 24
4 * 9 = 36

产生新的值

count() 函数返回一个产生一列连续整数的迭代器。我们可以传递一个参数来设定起始值。与内置函数 range() 不同,不需要给出一个参数来设定上限。

itertools_count.py

from itertools import *

for i in zip(count(1), ['a', 'b', 'c']):
    print(i)

因为列表参数是有限的,所以这个程序也是会终止的。

$ python3 itertools_count.py

(1, 'a')
(2, 'b')
(3, 'c')

count() 函数的起始和步长参数可以是任意可相加的数。

itertools_count_step.py

import fractions
from itertools import *

start = fractions.Fraction(1, 3)
step = fractions.Fraction(1, 3)

for i in zip(count(start, step), ['a', 'b', 'c']):
    print('{}: {}'.format(*i))

在这个例子中,起始值和步长都是 fraction 模块中的 Fraction 对象。

$ python3 itertools_count_step.py

1/3: a
2/3: b
1: c

cycle() 函数将会把输入的可迭代对象无限循环输出的迭代器。因此,这个函数会记住整个输入的迭代器,因此,这在输入迭代对象较长时会占用较多的内存。

itertools_cycle.py

from itertools import *

for i in zip(range(7), cycle(['a', 'b', 'c'])):
    print(i)

这里用了一个计数变量来跳出循环。

$ python3 itertools_cycle.py

(0, 'a')
(1, 'b')
(2, 'c')
(3, 'a')
(4, 'b')
(5, 'c')
(6, 'a')

repeat() 函数返回的迭代器会把一个值重复几次输出。

itertools_repeat.py

from itertools import *

for i in repeat('over-and-over', 5):
    print(i)

如果没有设定次数,那么 repeat() 函数会一直输出同一个值。我们也可以提供 time 参数来限制迭代的次数。

$ python3 itertools_repeat.py

over-and-over
over-and-over
over-and-over
over-and-over
over-and-over

repeat()zip() 以及 map() 组合起来使用是一种常用的,把一个值和其他迭代器组合在一起的方法。

itertools_repeat_zip.py

from itertools import *

for i, s in zip(count(), repeat('over-and-over', 5)):
    print(i, s)

这里,我们将一个常量通过 reapeat() 和一个迭代器组合在一起。

$ python3 itertools_repeat_zip.py

0 over-and-over
1 over-and-over
2 over-and-over
3 over-and-over
4 over-and-over

这个例子中,我们用 map() 函数将 0 到 4 的整数乘以了 2。

itertools_repeat_map.py

from itertools import *

for i in map(lambda x, y: (x, y, x * y), repeat(2), range(5)):
    print('{:d} * {:d} = {:d}'.format(*i))

这里的 repeat() 迭代器不需要给出迭代的次数。因为 map() 函数会自动在其中任意一个输入终止时终止。而这里的 range() 只返回五个值。

$ python3 itertools_repeat_map.py

2 * 0 = 0
2 * 1 = 2
2 * 2 = 4
2 * 3 = 6
2 * 4 = 8

过滤

dropwhile() 函数返回一个迭代器,其中的元素为原迭代器中,给定条件首次为假之后的所有元素。

itertools_dropwhile.py

from itertools import *

def should_drop(x):
    print('Testing:', x)
    return x < 1

for i in dropwhile(should_drop, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)

dropwhile() 并不过滤所有元素;当条件首次为假后,原迭代器中剩余元素将全部返回。

$ python3 itertools_dropwhile.py

Testing: -1
Testing: 0
Testing: 1
Yielding: 1
Yielding: 2
Yielding: -2

与 dropwhile() 相对的是 takewhile()。它将原迭代器中,直到给定条件为假之前的所有元素,作为新的迭代器返回。

itertools_takewhile.py

from itertools import *

def should_take(x):
    print('Testing:', x)
    return x < 2

for i in takewhile(should_take, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)

一旦 should_take() 返回 Falsetakewhile() 就停止处理输入。

$ python3 itertools_takewhile.py

Testing: -1
Yielding: -1
Testing: 0
Yielding: 0
Testing: 1
Yielding: 1
Testing: 2

内建函数 filter() 返回一个迭代器,只包含使测试函数为真的所有元素。

itertools_filter.py

from itertools import *

def check_item(x):
    print('Testing:', x)
    return x < 1

for i in filter(check_item, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)

dropwhile()takewhile() 不同的是,filter() 返回前,所有元素都会被测试。

$ python3 itertools_filter.py

Testing: -1
Yielding: -1
Testing: 0
Yielding: 0
Testing: 1
Testing: 2
Testing: -2
Yielding: -2

filterfalse() 返回一个迭代器,其中只包含使测试函数为假的所有元素。

itertools_filterfalse.py

from itertools import *

def check_item(x):
    print('Testing:', x)
    return x < 1

for i in filterfalse(check_item, [-1, 0, 1, 2, -2]):
    print('Yielding:', i)

check_item() 中的内容与前例相同,所以此例中 filterfalse() 返回的结果恰好与前例相反。

$ python3 itertools_filterfalse.py

Testing: -1
Testing: 0
Testing: 1
Yielding: 1
Testing: 2
Yielding: 2
Testing: -2

compress() 提供了另一种方法来过滤序列。它不是调用一个测试函数,而是使用另外一个序列中的值来决定元素的取舍。

itertools_compress.py

from itertools import *

every_third = cycle([False, False, True])
data = range(1, 10)

for i in compress(data, every_third):
    print(i, end=' ')
print()

第一个参数是待处理的输入序列, 而第二个参数是一个选择器序列,其中的每个布尔值依次决定了是否取用输入序列中的元素(真则取用,假则舍弃)。

$ python3 itertools_compress.py

3 6 9

数据分组

groupby() 函数返回一个迭代器,其中的每个元素是有一个共同的键的一组值。这个例子中展示了根据一个属性来对相关数据进行分组的方法。

itertools_groupby_seq.py

import functools
from itertools import *
import operator
import pprint

@functools.total_ordering
class Point:

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return '({}, {})'.format(self.x, self.y)

    def __eq__(self, other):
        return (self.x, self.y) == (other.x, other.y)

    def __gt__(self, other):
        return (self.x, self.y) > (other.x, other.y)

# 为 Point 实例创建一个数据集
data = list(map(Point,
                cycle(islice(count(), 3)),
                islice(count(), 7)))
print('Data:')
pprint.pprint(data, width=35)
print()

# 将未排序的数据按X值分组
print('Grouped, unsorted:')
for k, g in groupby(data, operator.attrgetter('x')):
    print(k, list(g))
print()

# 对数据进行排序
data.sort()
print('Sorted:')
pprint.pprint(data, width=35)
print()

# 将排序后的数据按X值分组
print('Grouped, sorted:')
for k, g in groupby(data, operator.attrgetter('x')):
    print(k, list(g))
print()

为了使分组能正确工作,输入序列需要按键值排好序。

$ python3 itertools_groupby_seq.py

Data:
[(0, 0),
 (1, 1),
 (2, 2),
 (0, 3),
 (1, 4),
 (2, 5),
 (0, 6)]

Grouped, unsorted:
0 [(0, 0)]
1 [(1, 1)]
2 [(2, 2)]
0 [(0, 3)]
1 [(1, 4)]
2 [(2, 5)]
0 [(0, 6)]

Sorted:
[(0, 0),
 (0, 3),
 (0, 6),
 (1, 1),
 (1, 4),
 (2, 2),
 (2, 5)]

Grouped, sorted:
0 [(0, 0), (0, 3), (0, 6)]
1 [(1, 1), (1, 4)]
2 [(2, 2), (2, 5)]

联结输入

accumulate()函数将输入序列的第 n 和第 n+1 个元素传入给定函数,产出返回值。缺省情况下,函数将返回两个输入参数的和, 所以 accumulate() 可以用来得到一个数字序列的累加和。

itertools_accumulate.py

from itertools import *

print(list(accumulate(range(5))))
print(list(accumulate('abcde')))

若用在非整数序列上,结果取决于加操作对两个元素的意义。此代码中第二个例子展示了当 accumulate() 输入为字符串时,返回的结果为此字符串逐次变长的前缀部分。

$ python3 itertools_accumulate.py

[0, 1, 3, 6, 10]
['a', 'ab', 'abc', 'abcd', 'abcde']

你可以将 accumulate() 与任何接受两个参数的函数一起使用,来得到不同的结果。

itertools_accumulate_custom.py

from itertools import *

def f(a, b):
    print(a, b)
    return b + a + b

print(list(accumulate('abcde', f)))

这个例子中把字符串联结成一个个回文字符串。在每次迭代中调用 f() 时,打印 accumulate() 传给它的两个参数的值。

$ python3 itertools_accumulate_custom.py

a b
bab c
cbabc d
dcbabcd e
['a', 'bab', 'cbabc', 'dcbabcd', 'edcbabcde']

product() 常用来取代对多个序列的嵌套 for 循环,返回一个包含所有输入组合的笛卡儿积的迭代器。

itertools_product.py

from itertools import *
import pprint

FACE_CARDS = ('J', 'Q', 'K', 'A')
SUITS = ('H', 'D', 'C', 'S')

DECK = list(
    product(
        chain(range(2, 11), FACE_CARDS),
        SUITS,
    )
)

for card in DECK:
    print('{:>2}{}'.format(*card), end=' ')
    if card[1] == SUITS[-1]:
        print()

product() 产出的每个元素是一个元组,其中的成员依次取自传入的各序列。第一个返回的元组的成员依次是传入的各序列的第一个元素。最后一个传入 product() 的序列将首先迭代,然后是倒数第二个,依次类推。这样得到的结果将对第一个序列有序,然后对第二个序列有序,等等。

此例中,纸牌的序列将对牌面数字有序,然后是花色。

$ python3 itertools_product.py

 2H  2D  2C  2S
 3H  3D  3C  3S
 4H  4D  4C  4S
 5H  5D  5C  5S
 6H  6D  6C  6S
 7H  7D  7C  7S
 8H  8D  8C  8S
 9H  9D  9C  9S
10H 10D 10C 10S
 JH  JD  JC  JS
 QH  QD  QC  QS
 KH  KD  KC  KS
 AH  AD  AC  AS

要改变纸牌的顺序,把传入 product() 的顺序改一下就可以了。

itertools_product_ordering.py

from itertools import *
import pprint

FACE_CARDS = ('J', 'Q', 'K', 'A')
SUITS = ('H', 'D', 'C', 'S')

DECK = list(
    product(
        SUITS,
        chain(range(2, 11), FACE_CARDS),
    )
)

for card in DECK:
    print('{:>2}{}'.format(card[1], card[0]), end=' ')
    if card[1] == FACE_CARDS[-1]:
        print()

此例中,循环打印直到输出 A,而不是黑桃(S),则回车换行。

$ python3 itertools_product_ordering.py

 2H  3H  4H  5H  6H  7H  8H  9H 10H  JH  QH  KH  AH
 2D  3D  4D  5D  6D  7D  8D  9D 10D  JD  QD  KD  AD
 2C  3C  4C  5C  6C  7C  8C  9C 10C  JC  QC  KC  AC
 2S  3S  4S  5S  6S  7S  8S  9S 10S  JS  QS  KS  AS

要得到一个序列与自身的笛卡儿积,需指定输入要重复的次数。

itertools_product_repeat.py

from itertools import *

def show(iterable):
    for i, item in enumerate(iterable, 1):
        print(item, end=' ')
        if (i % 3) == 0:
            print()
    print()

print('Repeat 2:\n')
show(list(product(range(3), repeat=2)))

print('Repeat 3:\n')
show(list(product(range(3), repeat=3)))

因为重复一个序列相当于将此序列多次传入,所以 product() 产出的每个元组将含有与重复次数相等数量的成员。

$ python3 itertools_product_repeat.py

Repeat 2:

(0, 0) (0, 1) (0, 2)
(1, 0) (1, 1) (1, 2)
(2, 0) (2, 1) (2, 2)

Repeat 3:

(0, 0, 0) (0, 0, 1) (0, 0, 2)
(0, 1, 0) (0, 1, 1) (0, 1, 2)
(0, 2, 0) (0, 2, 1) (0, 2, 2)
(1, 0, 0) (1, 0, 1) (1, 0, 2)
(1, 1, 0) (1, 1, 1) (1, 1, 2)
(1, 2, 0) (1, 2, 1) (1, 2, 2)
(2, 0, 0) (2, 0, 1) (2, 0, 2)
(2, 1, 0) (2, 1, 1) (2, 1, 2)
(2, 2, 0) (2, 2, 1) (2, 2, 2)

permutations() 函数产出输入序列的所有给定长度的排列。默认返回全排列(与原序列长度相等)。

itertools_permutations.py

from itertools import *

def show(iterable):
    first = None
    for i, item in enumerate(iterable, 1):
        if first != item[0]:
            if first is not None:
                print()
            first = item[0]
        print(''.join(item), end=' ')
    print()

print('All permutations:\n')
show(permutations('abcd'))

print('\nPairs:\n')
show(permutations('abcd', r=2))

可用 r 参数来限定返回的每个排列的长度与数量。

$ python3 itertools_permutations.py

All permutations:

abcd abdc acbd acdb adbc adcb
bacd badc bcad bcda bdac bdca
cabd cadb cbad cbda cdab cdba
dabc dacb dbac dbca dcab dcba

Pairs:

ab ac ad
ba bc bd
ca cb cd
da db dc

要返回所有不重复的组合而不是排列, 使用 combinations()。如果输入序列的所有元素都不重复,输入中将不会有任何重复的值。

itertools_combinations.py

from itertools import *

def show(iterable):
    first = None
    for i, item in enumerate(iterable, 1):
        if first != item[0]:
            if first is not None:
                print()
            first = item[0]
        print(''.join(item), end=' ')
    print()

print('Unique pairs:\n')
show(combinations('abcd', r=2))

permutations() 不同,combinations()r 参数不能省略。

$ python3 itertools_combinations.py

Unique pairs:

ab ac ad
bc bd
cd

由于 combinations() 不会重复输入序列的元素,而有时又需要考虑包含重复元素的组合,对于这种情况,使用 combinations_with_replacement().

itertools_combinations_with_replacement.py

from itertools import *

def show(iterable):
    first = None
    for i, item in enumerate(iterable, 1):
        if first != item[0]:
            if first is not None:
                print()
            first = item[0]
        print(''.join(item), end=' ')
    print()

print('Unique pairs:\n')
show(combinations_with_replacement('abcd', r=2))

在输出中,每个输入的元素都将和自身配对,同时也和序列中的其他元素配对。

$ python3 itertools_combinations_with_replacement.py

Unique pairs:

aa ab ac ad
bb bc bd
cc cd
dd

另请参考

本文章首发在 LearnKu.com 网站上。
上一篇 下一篇
讨论数量: 0
发起讨论 只看当前版本


暂无话题~