Hi everyone,
I am trying to build the following node:
node(
    name="test-agent",
    # One composite input: "<param>:<dataset>" pairs joined by "::"; the custom
    # DataCatalog turns it into a single object passed to `func`.
    inputs="master:claude-sonnet::worker:mistral-large",
    func=lambda x: None,
    outputs=None,
)
from typing import Any

from kedro import io


class DataCatalog(io.DataCatalog):
    """Kedro ``DataCatalog`` that can assemble an ``AgentLLM`` from several datasets.

    A dataset name containing ``"::"`` is treated as a composite spec made of
    ``"<param>:<dataset>"`` segments, e.g.
    ``"master:claude-sonnet::worker:mistral-large"``. Each referenced dataset is
    loaded and passed to ``AgentLLM`` as the keyword argument ``<param>``.

    NOTE: only ``load`` is overridden here. Membership checks (``in``) still use
    the base implementation, so Kedro's runner rejects composite names with
    "Pipeline input(s) ... not found in the DataCatalog" before ``load`` is
    ever reached.
    """

    @staticmethod
    def _parse_agent_spec(name: str) -> dict[str, str]:
        """Split a composite spec into an ordered ``{param: dataset_name}`` mapping.

        Only the first ``":"`` of each segment separates the parameter from the
        dataset name, so the dataset part may itself contain ``":"``.

        Raises:
            ValueError: if a segment is not of the form ``"<param>:<dataset>"``
                or a parameter name appears more than once.
        """
        spec: dict[str, str] = {}
        for segment in name.split("::"):
            param, sep, dataset = segment.partition(":")
            if not (sep and param and dataset):
                raise ValueError(
                    f"Invalid agent segment {segment!r} in {name!r}; "
                    "expected '<param>:<dataset>'."
                )
            if param in spec:
                raise ValueError(f"Duplicate agent parameter {param!r} in {name!r}.")
            spec[param] = dataset
        return spec

    def load(self, name: str, version: str | None = None) -> Any:
        """Load a dataset, or build an ``AgentLLM`` for a composite ``"::"`` name.

        Args:
            name: A plain dataset name, or a composite agent spec.
            version: Forwarded to every underlying ``load`` call (the same
                version is used for each referenced dataset).

        Returns:
            The loaded dataset for a plain name; otherwise
            ``AgentLLM(**{param: loaded_dataset, ...})``.

        Raises:
            ValueError: if a composite name is malformed.
        """
        if "::" not in name:
            return super().load(name=name, version=version)
        # Agent-LLM design. A plain loop (not a comprehension) keeps zero-arg
        # super() valid on Python versions where comprehensions get their own scope.
        params = {}
        for param, dataset in self._parse_agent_spec(name).items():
            params[param] = super().load(name=dataset, version=version)
        # NOTE(review): AgentLLM is not defined in this snippet; assumed imported elsewhere.
        return AgentLLM(**params)
I have `claude-sonnet` and `mistral-large` defined in my catalog, along with 10 other models. I wanted to customize the catalog's behavior so that multiple models can be passed into a new class through a single composite input name, `"master:claude-sonnet::worker:mistral-large"`, which creates an `AgentLLM` instance with `master` and `worker` parameters. Unfortunately, this doesn't work as I expected: the runner raises the following error before execution even reaches my catalog's `load` method:

```
ValueError: Pipeline input(s) {'master:claude-sonnet::worker:mistral-large'} not found in the DataCatalog
```
Okay, so I made it work with the following additions:
class DataCatalog(io.DataCatalog):
    """Kedro ``DataCatalog`` that resolves composite "agent" dataset names.

    A name containing ``"::"`` is a composite spec made of
    ``"<param>:<dataset>"`` segments, e.g.
    ``"master:claude-sonnet::worker:mistral-large"``. ``load`` turns it into
    ``AgentLLM(master=<claude-sonnet>, worker=<mistral-large>)``.
    ``__contains__`` and ``_get_dataset`` are overridden so the runner's checks
    accept such names. All three methods share one parser, so they always agree
    on which datasets a composite name refers to.
    """

    @staticmethod
    def _parse_agent_spec(name: str) -> dict[str, str]:
        """Split a composite spec into an ordered ``{param: dataset_name}`` mapping.

        Only the first ``":"`` of each segment separates the parameter from the
        dataset name, so the dataset part may itself contain ``":"``.

        Raises:
            ValueError: if a segment is not of the form ``"<param>:<dataset>"``
                or a parameter name appears more than once.
        """
        spec: dict[str, str] = {}
        for segment in name.split("::"):
            param, sep, dataset = segment.partition(":")
            if not (sep and param and dataset):
                raise ValueError(
                    f"Invalid agent segment {segment!r} in {name!r}; "
                    "expected '<param>:<dataset>'."
                )
            if param in spec:
                raise ValueError(f"Duplicate agent parameter {param!r} in {name!r}.")
            spec[param] = dataset
        return spec

    def __contains__(self, dataset_name: str) -> bool:
        """Return whether ``dataset_name`` (plain or composite) is in the catalog.

        A composite name is contained when every dataset it references is. A
        malformed composite name is reported as not contained rather than raising.
        """
        if "::" not in dataset_name:
            return super().__contains__(dataset_name)
        # Added for runner checks that verify pipeline inputs exist via `in`.
        try:
            spec = self._parse_agent_spec(dataset_name)
        except ValueError:
            return False
        return all(dataset in self for dataset in spec.values())

    def _get_dataset(
        self,
        dataset_name: str,
        version: Version | None = None,
        suggest: bool = True,
    ) -> AbstractDataset:
        """Return the dataset for a plain name; for a composite name, a list of them.

        Raises:
            ValueError: if a composite name is malformed.
        """
        if "::" not in dataset_name:
            return super()._get_dataset(
                dataset_name=dataset_name, version=version, suggest=suggest
            )
        # Added for runner checks: a composite name has no single backing dataset.
        # The return value is ignored by that check, so a list of the underlying
        # datasets is returned despite the annotation; each lookup still fails
        # if a referenced dataset is missing.
        return [
            self._get_dataset(dataset_name=dataset, version=version, suggest=suggest)
            for dataset in self._parse_agent_spec(dataset_name).values()
        ]

    def load(self, name: str, version: str | None = None) -> Any:
        """Load a dataset, or build an ``AgentLLM`` for a composite ``"::"`` name.

        Args:
            name: A plain dataset name, or a composite agent spec.
            version: Forwarded to every underlying ``load`` call (the same
                version is used for each referenced dataset).

        Returns:
            The loaded dataset for a plain name; otherwise
            ``AgentLLM(**{param: loaded_dataset, ...})``.

        Raises:
            ValueError: if a composite name is malformed.
        """
        if "::" not in name:
            return super().load(name=name, version=version)
        # Agent-LLM design. A plain loop (not a comprehension) keeps zero-arg
        # super() valid on Python versions where comprehensions get their own scope.
        params = {}
        for param, dataset in self._parse_agent_spec(name).items():
            params[param] = super().load(name=dataset, version=version)
        # NOTE(review): AgentLLM is not defined in this snippet; assumed imported elsewhere.
        return AgentLLM(**params)