
Can async functions be passed to nodes?


Hey team, is there any way to pass async functions to nodes?


I don't think we've ever had this question before

if you try it, what error do you get?

I think you need to introduce a custom runner, which you can do with:

kedro run --runner class.path.of.your.AsyncRunner 

ChatGPT has suggested this, which feels sensible: essentially awaiting certain functions rather than just executing them:

from __future__ import annotations

import asyncio
import functools
from collections import Counter
from itertools import chain
from typing import Any, Callable

from pluggy import PluginManager

from kedro.io import CatalogProtocol
from kedro.pipeline import Pipeline
from kedro.runner import AbstractRunner, run_node


def _as_sync(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a coroutine function so it can be called synchronously."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return asyncio.run(func(*args, **kwargs))

    return wrapper


class SequentialAsyncRunner(AbstractRunner):
    """``SequentialAsyncRunner`` is an ``AbstractRunner`` implementation that
    runs the ``Pipeline`` sequentially, driving any coroutine node functions
    to completion on an event loop.
    """

    def _run(
        self,
        pipeline: Pipeline,
        catalog: CatalogProtocol,
        hook_manager: PluginManager,
        session_id: str | None = None,
    ) -> None:
        """The method implementing sequential pipeline running with support
        for async node functions.

        ``AbstractRunner.run()`` invokes ``_run()`` synchronously, so this
        method cannot itself be a coroutine; instead, each async node
        function is wrapped so that ``asyncio.run`` executes it.

        Args:
            pipeline: The ``Pipeline`` to run.
            catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
            hook_manager: The ``PluginManager`` to activate hooks.
            session_id: The id of the session.

        Raises:
            Exception: in case of any downstream node failure.
        """
        if not self._is_async:
            self._logger.info(
                "Using synchronous mode for loading and saving data. Use the "
                "--async flag for potential performance gains. "
                "https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously"
            )

        nodes = pipeline.nodes
        done_nodes = set()

        load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))

        for exec_index, node in enumerate(nodes):
            try:
                if asyncio.iscoroutinefunction(node.func):
                    # ``run_node`` calls ``node.func`` synchronously, so an
                    # unwrapped coroutine function would return an un-awaited
                    # coroutine object. Swap in a sync wrapper that runs it on
                    # an event loop (uses the private ``Node._copy`` helper).
                    node = node._copy(func=_as_sync(node.func))
                run_node(node, catalog, hook_manager, self._is_async, session_id)
                done_nodes.add(node)
            except Exception:
                self._suggest_resume_scenario(pipeline, done_nodes, catalog)
                raise

            # Decrement load counts and release any datasets we've finished with
            for dataset in node.inputs:
                load_counts[dataset] -= 1
                if load_counts[dataset] < 1 and dataset not in pipeline.inputs():
                    catalog.release(dataset)
            for dataset in node.outputs:
                if load_counts[dataset] < 1 and dataset not in pipeline.outputs():
                    catalog.release(dataset)

            self._logger.info(
                "Completed %d out of %d tasks", exec_index + 1, len(nodes)
            )
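With the class saved somewhere importable, e.g. src/my_project/runners.py (path and package name are just placeholders), it would be selected the same way as above:

kedro run --runner my_project.runners.SequentialAsyncRunner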

did that work for you?

I found a nice workaround. I now have async functions, named e.g. aprocess_data(). I then have a wrapper function called process_data() with the same parameters. I call process_data from my pipeline and aprocess_data from my notebook

process_data calls asyncio.run(aprocess_data())
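A minimal sketch of that pattern, for anyone reading along (the DataFrame signature and dataset names are placeholders, not the poster's actual code):

import asyncio

import pandas as pd

from kedro.pipeline import node


async def aprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    # ... await LLM calls or other I/O here ...
    return df


def process_data(df: pd.DataFrame) -> pd.DataFrame:
    # Sync wrapper with the same parameters; this is what the node calls.
    return asyncio.run(aprocess_data(df))


# The pipeline registers the sync wrapper like any other node function:
process_node = node(process_data, inputs="raw_table", outputs="processed_table")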

Nice! That’s sensible, thanks for feeding back to the community :kedroid:

hi! You've raised an interesting point. Do you have an example of how you use your workaround?

I use it to parallelize a slow LLM function that runs for every row in my table
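As a hedged sketch of that use case: fan out one request per row and gather the results, with a semaphore capping concurrent calls (the client function, concurrency limit, and column names are all assumptions):

import asyncio

import pandas as pd


async def _call_llm(prompt: str) -> str:
    # Placeholder for the real async LLM client call.
    ...


async def aprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    # Cap concurrency so the LLM API isn't flooded (limit is arbitrary).
    sem = asyncio.Semaphore(10)

    async def one_row(prompt: str) -> str:
        async with sem:
            return await _call_llm(prompt)

    # One task per row, run concurrently instead of one-by-one.
    results = await asyncio.gather(*(one_row(p) for p in df["prompt"]))
    return df.assign(llm_output=results)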
