Customizing Data Catalog Loader to Handle Multiple Models

Hi everyone,

I am trying to build the following node:

from kedro.pipeline import node

node(
    func=lambda x: None,  # placeholder; the real node consumes the loaded AgentLLM
    name="test-agent",
    inputs="master:claude-sonnet::worker:mistral-large",
    outputs=None,
)

Here I customized my data catalog loader as follows:

from typing import Any

from kedro import io


class DataCatalog(io.DataCatalog):
    def load(self, name: str, version: str | None = None) -> Any:
        if "::" in name:  # Combined agent-LLM name, e.g. "master:claude-sonnet::worker:mistral-large"
            params = {}

            # Resolve each "param:dataset" pair and load the underlying model
            for n in name.split("::"):
                param, ds = n.split(":")
                params[param] = super().load(name=ds, version=version)

            # AgentLLM is my own wrapper class that takes the models as
            # keyword arguments (e.g. master=..., worker=...)
            return AgentLLM(**params)

        else:
            return super().load(name=name, version=version)

Essentially I have different models defined in my catalog: claude-sonnet and mistral-large, along with about 10 other models. I wanted to customize the load behavior so that multiple models can be passed into a new class at once.
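
To make this concrete, here is a minimal sketch of the setup (the MemoryDataset entries are placeholders for my real model datasets, and AgentLLM is simplified to just store whatever it is given):

from kedro.io import MemoryDataset

# Placeholder entries; in the real project these are custom datasets that
# load the actual LLM clients. DataCatalog is the subclass defined above.
catalog = DataCatalog(
    {
        "claude-sonnet": MemoryDataset(data="claude client"),
        "mistral-large": MemoryDataset(data="mistral client"),
    }
)


class AgentLLM:
    """Simplified stand-in: holds the models passed as keyword arguments."""

    def __init__(self, **models):
        self.models = models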

Until now I have been passing them in one by one, but it turns out I have cases where I need to pass 3-4 different models and create intermediate classes that use them under the hood. I wanted to improve this setup with a new syntax like "master:claude-sonnet::worker:mistral-large", which creates an AgentLLM instance with master and worker parameters. Unfortunately this doesn't work as I expected, since I hit the following error in the runner code before it even reaches my catalog method:

ValueError: Pipeline input(s) {'master:claude-sonnet::worker:mistral-large'} not found in the DataCatalog
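
As far as I can tell, the error comes from the runner's pre-flight validation, which runs before any load call. Conceptually it does something like the following (a paraphrased sketch, not the exact Kedro source):

# Paraphrased sketch of the runner's input validation: every pipeline input
# must be "in" the catalog before anything is loaded, so a plain DataCatalog
# rejects the combined name outright.
def check_inputs(pipeline_inputs: set[str], catalog) -> None:
    unsatisfied = {ds for ds in pipeline_inputs if ds not in catalog}
    if unsatisfied:
        raise ValueError(
            f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
        )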

What would be a good way to overcome this behaviour?

1 comment

Okay, so I made it work with the following additions πŸ˜„

from typing import Any

from kedro import io
from kedro.io import AbstractDataset, Version


class DataCatalog(io.DataCatalog):
    def __contains__(self, dataset_name: str) -> bool:
        if "::" in dataset_name:
            # Added so the runner's input checks accept the combined name:
            # it counts as "in" the catalog if every referenced dataset is.
            ds_checks = [n.split(":")[1] for n in dataset_name.split("::")]
            return all(d in self for d in ds_checks)
        else:
            return super().__contains__(dataset_name)

    def _get_dataset(
        self,
        dataset_name: str,
        version: Version | None = None,
        suggest: bool = True,
    ) -> AbstractDataset:
        if "::" in dataset_name:
            # Added for the runner's warm-up checks; the return value is
            # ignored there, so returning a list is acceptable.
            datasets = [n.split(":")[1] for n in dataset_name.split("::")]
            return [
                self._get_dataset(dataset_name=ds, version=version, suggest=suggest)
                for ds in datasets
            ]
        else:
            return super()._get_dataset(
                dataset_name=dataset_name, version=version, suggest=suggest
            )

    def load(self, name: str, version: str | None = None) -> Any:
        if "::" in name:  # Combined agent-LLM name
            params = {}

            for n in name.split("::"):
                param, ds = n.split(":")
                params[param] = super().load(name=ds, version=version)

            return AgentLLM(**params)

        else:
            return super().load(name=name, version=version)
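
A quick smoke test of the combined-name syntax, reusing the placeholder MemoryDataset entries and the simplified AgentLLM sketch from the question above:

from kedro.io import MemoryDataset

catalog = DataCatalog(
    {
        "claude-sonnet": MemoryDataset(data="claude client"),
        "mistral-large": MemoryDataset(data="mistral client"),
    }
)

# The runner's checks now pass...
assert "master:claude-sonnet::worker:mistral-large" in catalog

# ...and load() resolves both models and wraps them in an AgentLLM.
agent = catalog.load("master:claude-sonnet::worker:mistral-large")
assert agent.models == {"master": "claude client", "worker": "mistral client"}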
