
I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from inside an async for loop. Here is a simple class illustrating my idea:

import asyncio
from typing import List

class AsyncSimpleIterator:
    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")

        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]

        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    async def main(self):
        print("main() running")
        async for res in self.get_docs(batch_size=2):
            print(res)  # how to yield instead of print?

    def gen_batches(self):
        # how to get results of self.main() here?
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.main())
        loop.close()


DATA = ["Hello, world!"] * 4
iterator = AsyncSimpleIterator(DATA)
iterator.gen_batches()

So, my question is: how do I yield a result from main() so that I can gather it inside gen_batches()?

When I print the result inside main(), I get the following output:

main() running
get_docs() running
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [0, 1])
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [2, 3])
  • Is gen_batches() supposed to exhaust all of main(), or only one iteration? Why do you want to yield from main()? Commented Mar 30, 2018 at 13:25
  • @user4815162342 Yes, the generator in main() should exhaust in gen_batches(), so I can gather all results in gen_batches(). In a real world gen_batches() functionality would be a part of some other class and it should asynchronously get results from main(). Commented Mar 30, 2018 at 14:42

2 Answers


I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from inside an async for loop

Yielding from an async for works like a regular yield, except that it also has to be collected by an async for or equivalent. For example, the yield in get_docs makes it an async generator. If you replace print(res) with yield res in main(), it will make main() an async generator as well.
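A minimal sketch of this idea, using toy async generators rather than the class from the question (the names are illustrative):

```python
import asyncio

async def numbers():
    # An async generator: an async def that contains yield.
    for i in range(3):
        await asyncio.sleep(0)  # stand-in for real async work
        yield i

async def doubled():
    # Yielding from inside an async for makes this an
    # async generator as well, just like yield res in main() would.
    async for n in numbers():
        yield n * 2

async def main():
    # Async generators are consumed with async for (or an
    # async comprehension, as here).
    return [x async for x in doubled()]

print(asyncio.run(main()))  # → [0, 2, 4]
```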

the generator in main() should exhaust in gen_batches(), so I can gather all results in gen_batches()

To collect the values produced by an async generator (such as main() with print(res) replaced with yield res), you can use a helper coroutine:

def gen_batches(self):
    loop = asyncio.get_event_loop()
    async def collect():
        return [item async for item in self.main()]
    items = loop.run_until_complete(collect())
    loop.close()
    return items

The collect() helper makes use of a PEP 530 asynchronous comprehension, which can be thought of as syntactic sugar for the more explicit:

    async def collect():
        l = []
        async for item in self.main():
            l.append(item)
        return l
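On Python 3.7+, the explicit loop management can be replaced with asyncio.run(), which creates a fresh event loop, runs the coroutine, and closes the loop for you. A sketch with a stand-in class (Demo and its trivial main() are illustrative, not from the question):

```python
import asyncio

class Demo:
    async def main(self):
        # A trivial async generator standing in for the real main().
        for i in range(3):
            yield i

    def gen_batches(self):
        async def collect():
            return [item async for item in self.main()]
        # asyncio.run() handles loop creation and cleanup (Python 3.7+).
        return asyncio.run(collect())

print(Demo().gen_batches())  # → [0, 1, 2]
```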

1 Comment

Thank you! I didn't realize that I can yield from an async for loop only inside another async for loop. I've rewritten the code with your hints and will post it below.

A working solution based on @user4815162342 answer to the original question:

import asyncio
from typing import List


class AsyncSimpleIterator:

    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")

        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]

        for doc_ids in batches:
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    def gen_batches(self):
        loop = asyncio.get_event_loop()

        async def collect():
            return [j async for j in self.get_docs(batch_size=2)]

        items = loop.run_until_complete(collect())
        loop.close()
        return items


DATA = ["Hello, world!"] * 4
iterator = AsyncSimpleIterator(DATA)
result = iterator.gen_batches()
print(result)
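If you don't need all batches in memory at once, the same pattern can also drive an async generator lazily from synchronous code, one batch per step. A sketch with a toy async generator (batches() and lazy_batches() are illustrative names, not from the answer above):

```python
import asyncio

async def batches():
    # A toy async generator producing two batches of ids.
    for i in range(0, 4, 2):
        yield list(range(i, i + 2))

def lazy_batches(agen):
    # Drive the async generator one step at a time from sync code
    # by awaiting its __anext__() coroutine on a private loop.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()

for batch in lazy_batches(batches()):
    print(batch)  # → [0, 1] then [2, 3]
```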
