# coding: utf-8
# 2020/1/13 @ tongshiwei
import json
import os
import warnings
import functools
import queue
import threading
# import multiprocessing as mp
import multiprocess as mp
from tqdm import tqdm
from longling import wf_open, loading
__all__ = ["BaseIter", "MemoryIter", "LoopIter", "AsyncLoopIter", "AsyncIter", "CacheAsyncLoopIter",
"iterwrap"]
class Register(dict):  # pragma: no cover
    """Name -> class registry; warns before overriding an existing key."""

    def add(self, cls):
        """Register *cls* under ``cls.__name__`` and return it (decorator-friendly)."""
        key = cls.__name__
        if key in self:
            warnings.warn("key %s has already existed, which will be overridden" % key)
        self[key] = cls
        return cls
# Global registry mapping iterator class names to classes; consulted by iterwrap().
register = Register()
@register.add
class BaseIter(object):
    """
    Basic resettable iterator.

    Parameters
    ----------
    src:
        Either an iterable instance, or a callable returning an iterable.
    fargs: list, optional
        Positional arguments passed to ``src`` when it is callable.
    fkwargs: dict, optional
        Keyword arguments passed to ``src`` when it is callable.
    length: int, optional
        Number of elements, if known in advance.

    Notes
    ------
    * If ``src`` is an iterator instance, its content is exhausted after one
      pass and the iterator cannot be restarted.
    * To make the instance restartable, pass a callable that produces a fresh
      iterator as ``src``, and call ``reset()`` after each pass.
    * If ``src`` has no ``__len__``, ``len()`` cannot be applied to a
      ``BaseIter`` instance before the first pass finishes.

    Examples
    --------
    .. code-block:: python

        # exhausted after a single pass
        with open("demo.txt") as f:
            bi = BaseIter(f)
            for line in bi:
                pass

        # restartable
        def open_file():
            with open("demo.txt") as f:
                for line in f:
                    yield line

        bi = BaseIter(open_file)
        for _ in range(5):
            for line in bi:
                pass
            bi.reset()

        # simplified restartable form
        @BaseIter.wrap
        def open_file():
            with open("demo.txt") as f:
                for line in f:
                    yield line

        bi = open_file()
        for _ in range(5):
            for line in bi:
                pass
            bi.reset()
    """

    def __init__(self, src, fargs=None, fkwargs=None, length=None, *args, **kwargs):
        self._reset = src
        self._fargs: list = fargs if fargs is not None else []
        self._fkwargs: dict = fkwargs if fkwargs is not None else {}
        self._data = None
        self._length = length
        self._count = 0
        self.init()

    def __next__(self):
        try:
            self._count += 1
            return next(self._data)
        except StopIteration:
            # the failed advance must not be counted
            self._count -= 1
            self._set_length()
            raise

    def __iter__(self):
        return self

    def __len__(self):
        if self._length is None:
            raise TypeError("length is unknown")
        return self._length

    def init(self):
        self.reset()

    def reset(self):
        """Rebuild the underlying iterator (re-invoking ``src`` when callable)."""
        _data = self._reset(*self._fargs, **self._fkwargs) if callable(self._reset) else self._reset
        try:
            self._length = len(_data)
        except TypeError:
            # source has no __len__: fall back to the count of a finished pass
            self._set_length()
        self._data = iter(_data)

    def _set_length(self):
        # only fix the length once at least one element has been produced
        if self._length is None and self._count > 0:
            self._length = self._count

    @classmethod
    def wrap(cls, f):
        """Decorator turning a generator function into a restartable-iterator factory."""
        @functools.wraps(f)
        def _f(*fargs, **fkwargs):
            return cls(f, fargs=fargs, fkwargs=fkwargs)

        return _f
@register.add
class MemoryIter(BaseIter):
    """
    In-memory iterator.

    Caches every element produced by the source, so that passes after the
    first are served from memory instead of re-reading the source.

    Parameters
    ----------
    prefetch: bool
        When True, eagerly load the whole source into memory on construction.
    """

    def __init__(self, src, fargs=None, fkwargs=None, length=None, prefetch=False, *args, **kwargs):
        self._memory_data = []
        self._in_memory = True  # True while the cache is still being filled
        super(MemoryIter, self).__init__(src, fargs, fkwargs, length)
        self.prefetch = prefetch
        if self.prefetch:
            self._prefetch()

    def _prefetch(self):
        # eagerly drain the source into the in-memory cache
        for _data in tqdm(self._data, "prefetching data"):
            self._count += 1
            self._memory_data.append(_data)
        self._set_length()
        self._in_memory = False
        self._data = iter(self._memory_data)

    def __next__(self):
        try:
            self._count += 1
            elem = next(self._data)
            if self._in_memory is True:
                # first pass: remember the element for later replays
                self._memory_data.append(elem)
            return elem
        except StopIteration:
            self._count -= 1
            self.reset()
            self._in_memory = False
            raise

    def reset(self):
        if not self._memory_data:
            super(MemoryIter, self).reset()
        else:
            # the cache is complete: replay from memory
            self._data = iter(self._memory_data)
            self._set_length()
@register.add
class LoopIter(BaseIter):
    """
    Looping iterator.

    Automatically calls ``reset()`` once a pass is exhausted, so the
    instance can be iterated again immediately.
    """

    def __init__(self, src, fargs=None, fkwargs=None, length=None, *args, **kwargs):
        super(LoopIter, self).__init__(src, fargs, fkwargs, length)

    def __next__(self):
        try:
            self._count += 1
            return next(self._data)
        except StopIteration:
            self._count -= 1
            # rebuild the underlying iterator before signalling end of pass
            self.reset()
            raise
def produce(data, produce_queue, fargs, fkwargs):
    """Worker target: drain *data* into *produce_queue*, then signal completion.

    Parameters
    ----------
    data:
        An iterable, or a callable returning one (invoked with *fargs*/*fkwargs*).
    produce_queue:
        Queue receiving every element, followed by exactly one terminal marker:
        a ``StopIteration`` instance on normal exhaustion, or the raised
        exception instance on failure.
    """
    data = data(*fargs, **fkwargs) if callable(data) else data
    try:
        for _data in data:
            produce_queue.put(_data)
    except Exception as e:  # pragma: no cover
        # forward the failure to the consumer side instead of dying silently
        produce_queue.put(e)
    else:
        # normal exhaustion: enqueue the end-of-data sentinel
        produce_queue.put(StopIteration())
@register.add
class AsyncLoopIter(LoopIter):
    """
    Asynchronous looping iterator, suitable for loading files.

    Loading and consumption are asynchronous: a worker thread/process fills a
    bounded queue while the consumer iterates. After ``reset()`` the next
    pass is prefetched automatically.

    Parameters
    ----------
    tank_size: int
        Capacity of the internal queue.
    timeout: float, optional
        Stored on the instance; not used by the visible code.
    level: str
        "t"/"thread"/"threading" for threads,
        "p"/"processing"/"multiprocessing" for processes.
    """

    def __init__(self, src, fargs=None, fkwargs=None, tank_size=8, timeout=None, level="t"):
        self.thread = None
        self._size = tank_size
        self.mode = self._mode_map(level)
        if self.mode == "t":
            self.queue_cls = queue.Queue
            self.thread_cls = threading.Thread
        elif self.mode == "p":
            self.queue_cls = mp.Queue
            self.thread_cls = mp.Process
        else:  # pragma: no cover
            raise TypeError("unknown mode: %s" % self.mode)
        self.queue = self.queue_cls(self._size)
        self._timeout = timeout
        super(AsyncLoopIter, self).__init__(src, fargs, fkwargs)

    @classmethod
    def _mode_map(cls, mode):
        """Normalize a user-supplied level alias to "t" or "p" (KeyError if unknown)."""
        map_dict = {
            "p": "p", "processing": "p", "multiprocessing": "p",
            "t": "t", "thread": "t", "threading": "t",
        }
        return map_dict[mode]

    def reset(self):
        if self.mode == "p":
            # process mode: ship the source callable itself to the worker,
            # which invokes it there (presumably to avoid pickling a live
            # iterator — TODO confirm)
            self._set_length()
            self._data = self._reset
        else:
            super(AsyncLoopIter, self).reset()
        if self.thread is not None:
            self.thread.join()
        self.thread = self.thread_cls(
            target=produce,
            kwargs=dict(data=self._data, produce_queue=self.queue, fargs=self._fargs, fkwargs=self._fkwargs),
            daemon=True
        )
        self.thread.start()

    def __next__(self):
        if self.queue is not None:
            item = self.queue.get()
        else:  # pragma: no cover
            raise StopIteration
        if isinstance(item, Exception):
            if isinstance(item, StopIteration):
                # end of pass: immediately start prefetching the next pass
                self.reset()
                raise StopIteration
            else:  # pragma: no cover
                # re-raise a failure forwarded by the worker
                raise item
        else:
            self._count += 1
            return item
@register.add
class AsyncIter(AsyncLoopIter):
    """
    Asynchronous iterator without automatic ``reset()``.

    Once exhausted, the worker and queue are discarded; call ``reset()``
    manually before iterating again.
    """

    def init(self):
        super(AsyncIter, self).reset()

    def reset(self):
        # rebuild the queue: the old one may still hold the end-of-data marker
        # NOTE(review): always uses queue.Queue here, even when level="p" —
        # verify against process-mode usage
        self.queue = queue.Queue(self._size)
        super(AsyncIter, self).reset()

    def __next__(self):
        if self.queue is not None:
            item = self.queue.get()
        else:
            raise StopIteration
        if isinstance(item, Exception):
            if isinstance(item, StopIteration):
                # exhausted: drop the worker/queue instead of restarting
                self.thread = None
                self.queue = None
                self._set_length()
                raise StopIteration
            else:  # pragma: no cover
                raise item
        else:
            self._count += 1
            return item
@register.add
class CacheAsyncLoopIter(AsyncLoopIter):
    """
    Asynchronous looping iterator with a cache file, suitable for sources
    with heavy preprocessing.

    Automatically resets. When ``src`` is a function whose preprocessing is
    much slower than consumption, every preprocessed item produced during the
    first pass is also written to ``cache_file`` (jsonl); subsequent passes
    load from that file instead of recomputing.

    Parameters
    ----------
    cache_file: str
        Path of the jsonl cache file.
    rerun: bool
        When False and ``cache_file`` exists, load from the cache instead of
        regenerating the data.
    """

    def __init__(self, src, cache_file, fargs=None, fkwargs=None, rerun=True, tank_size=8, timeout=None, level="t"):
        self.cache_file = cache_file
        self.cache_queue = None
        self.cache_thread = None

        if os.path.exists(self.cache_file) and rerun is False:
            # load from the existing cache file
            src = loading
            fargs = [self.cache_file, "jsonl"]
            fkwargs = None
            self._cache_stop = True
        else:
            # regenerate the data; a writer thread spills each item to the cache
            self.cache_queue = queue.Queue(tank_size)
            self.cache_thread = threading.Thread(target=self.cached, daemon=False)
            self.cache_thread.start()
            self._cache_stop = False

        super(CacheAsyncLoopIter, self).__init__(src, fargs, fkwargs, tank_size, timeout, level)

    def init(self):
        super(CacheAsyncLoopIter, self).reset()

    def reset(self):
        # after the first pass the cache file is complete: switch the source
        # to loading from it
        self._reset = loading
        self._fargs = [self.cache_file, "jsonl"]
        if self.cache_thread is not None:
            self.cache_thread.join()
        self.cache_thread = None
        self.cache_queue = None
        super(CacheAsyncLoopIter, self).reset()

    def __next__(self):
        item = self.queue.get()
        if self.cache_queue is not None:
            # forward every item (including the StopIteration marker, which
            # tells the writer thread to finish) to the cache writer
            self.cache_queue.put(item)
        if isinstance(item, Exception):
            if isinstance(item, StopIteration):
                self.reset()
                raise StopIteration
            else:  # pragma: no cover
                raise item
        else:
            self._count += 1
            return item

    def cached(self):
        """Writer-thread target: dump items from ``cache_queue`` to ``cache_file`` as jsonl."""
        assert self.cache_queue is not None
        with wf_open(self.cache_file, mode="w") as wf:
            while True:
                data = self.cache_queue.get()
                if isinstance(data, StopIteration):
                    break
                print(json.dumps(data), file=wf)
def iterwrap(itertype: str = "AsyncLoopIter", *args, **kwargs):
    """
    Iterator decorator: turns an iterator-producing function into a factory
    of reusable (restartable) iterators. ``AsyncLoopIter`` is used by default.

    Parameters
    ----------
    itertype: str
        Name of a class registered in ``register`` (e.g. "BaseIter",
        "MemoryIter", "LoopIter", "AsyncLoopIter", "AsyncIter",
        "CacheAsyncLoopIter").

    Raises
    ------
    TypeError
        If *itertype* is not a registered iterator type.

    Examples
    --------
    .. code-block:: python

        @iterwrap()
        def open_file():
            with open("demo.txt") as f:
                for line in f:
                    yield line

        data = open_file()
        for _ in range(5):
            for line in data:
                pass

    Warnings
    --------
    As mentioned in [1], on Windows or MacOS, `spawn()` is the default multiprocessing start method.
    Using `spawn()`, another interpreter is launched which runs your main script,
    followed by the internal worker function that receives parameters through pickle serialization.
    However, `decorator` ,`functools`, `lambda` and local function does not well fit `pickle` like discussed in [2].
    Therefore, since version 1.3.36, instead of using `multiprocessing`,
    we use `multiprocess` which replace `pickle` with `dill` .
    Nevertheless, the users should be aware of that `level='p'` may not work in windows and mac platform
    if the decorated function does not follow the `spawn()` behaviour.

    Notes
    ------
    Although `fork` in `multiprocessing` is quite easy to use, and iterwrap can work well with it,
    the users should still be aware of that `fork` is not safety enough as mentioned in [3].
    We use the default mode when deal with `multiprocessing`, i.e., `spawn` in windows and macos, and `folk` in linux.
    An example to change the default behaviour is `multiprocessing.set_start_method('spawn')`, which could
    be found in [3].

    References
    ----------
    [1] https://pytorch.org/docs/stable/data.html#platform-specific-behaviors
    [2] https://stackoverflow.com/questions/51867402/cant-\
    pickle-function-stringtongrams-at-0x104144f28-its-not-the-same-object
    [3] https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    """
    if itertype not in register:
        raise TypeError("itertype %s is unknown, the available type are %s" % (itertype, ", ".join(register)))

    def _f1(f):
        @functools.wraps(f)
        def _f2(*fargs, **fkwargs):
            # extra *args/**kwargs given to iterwrap() are forwarded to the
            # iterator class constructor
            return register[itertype](f, fargs=fargs, fkwargs=fkwargs, *args, **kwargs)

        return _f2

    return _f1