"""Show how to use multiple tool execution hooks, to run logic before and after a tool is called."""

import asyncio
import json
from inspect import isawaitable, iscoroutinefunction
from typing import Any, Callable, Dict

from agno.agent import Agent
from agno.tools import Toolkit
from agno.utils.log import logger

# ---------------------------------------------------------------------------
# Create Agent
# ---------------------------------------------------------------------------


class CustomerDBTools(Toolkit):
    """Toolkit exposing mock (hard-coded) customer-profile lookup and deletion tools."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Registration order is kept stable: retrieve first, then delete.
        self.register(self.retrieve_customer_profile)
        self.register(self.delete_customer_profile)

    # NOTE: the method docstrings below are left verbatim on purpose — the
    # toolkit turns them into the tool descriptions shown to the model.
    def retrieve_customer_profile(self, customer_id: str) -> str:
        """
        Retrieves a customer profile from the database.

        Args:
            customer_id: The ID of the customer to retrieve.

        Returns:
            A string containing the customer profile.
        """
        logger.info(f"Looking up customer profile for {customer_id}")
        # Static mock record. example.com is reserved for documentation
        # (RFC 2606), so the placeholder address can never reach a real inbox.
        return json.dumps(
            {
                "customer_id": customer_id,
                "name": "John Doe",
                "email": "john.doe@example.com",
            }
        )

    def delete_customer_profile(self, customer_id: str) -> str:
        """
        Deletes a customer profile from the database.

        Args:
            customer_id: The ID of the customer to delete.
        """
        logger.info(f"Deleting customer profile for {customer_id}")
        # Mock deletion: nothing is removed; a plain-text (non-JSON) message is returned.
        return f"Customer profile for {customer_id}"


def validation_hook(name: str, func: Callable, arguments: Dict[str, Any]):
    """Reject a forbidden customer ID, run the wrapped call, and sanitize its output.

    Args:
        name: Name of the tool being invoked.
        func: Callable that performs the tool call (presumably the next hook
            in the chain, or the tool itself).
        arguments: Keyword arguments forwarded to ``func``.

    Returns:
        The result of ``func``. For ``retrieve_customer_profile`` the JSON
        payload is re-serialized without its ``"name"`` field; results of
        other tools are returned unchanged.

    Raises:
        ValueError: If either customer-profile tool is called with
            ``customer_id == "123"``.
    """
    if name == "retrieve_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot retrieve customer profile for ID 123")

    if name == "delete_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot delete customer profile for ID 123")

    logger.info("Before Validation Hook")
    result = func(**arguments)
    logger.info("After Validation Hook")

    # Only the retrieve tool returns a JSON profile. Other tools (e.g. delete)
    # return plain text, which json.loads would reject with JSONDecodeError.
    if name == "retrieve_customer_profile":
        profile = json.loads(result)
        # Remove name from the profile to sanitize the output; tolerate its absence.
        profile.pop("name", None)
        return json.dumps(profile)
    return result


def logger_hook(name: str, func: Callable, arguments: Dict[str, Any]):
    """Surround a tool call with "before" and "after" log lines.

    Args:
        name: Tool name (not used here; part of the hook signature).
        func: Callable invoked with ``arguments``.
        arguments: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns, passed through untouched.
    """
    logger.info("Before Logger Hook")
    tool_output = func(**arguments)
    logger.info("After Logger Hook")
    return tool_output


sync_agent = Agent(
    # Hooks are executed in list order: validation first, then logging.
    tool_hooks=[validation_hook, logger_hook],
    tools=[CustomerDBTools()],
)


# ---------------------------------------------------------------------------
# Async Variant
# ---------------------------------------------------------------------------


class CustomerDBToolsAsync(Toolkit):
    """Toolkit with mock (hard-coded) customer-profile tools: an async lookup and a sync delete."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Registration order is kept stable: retrieve first, then delete.
        self.register(self.retrieve_customer_profile)
        self.register(self.delete_customer_profile)

    # NOTE: the method docstrings below are left verbatim on purpose — the
    # toolkit turns them into the tool descriptions shown to the model.
    async def retrieve_customer_profile(self, customer_id: str) -> str:
        """
        Retrieves a customer profile from the database.

        Args:
            customer_id: The ID of the customer to retrieve.

        Returns:
            A string containing the customer profile.
        """
        logger.info(f"Looking up customer profile for {customer_id}")
        # Static mock record. example.com is reserved for documentation
        # (RFC 2606), so the placeholder address can never reach a real inbox.
        return json.dumps(
            {
                "customer_id": customer_id,
                "name": "John Doe",
                "email": "john.doe@example.com",
            }
        )

    # NOTE(review): left synchronous, presumably so the example exercises both
    # the sync and the async branches of the async hooks.
    def delete_customer_profile(self, customer_id: str) -> str:
        """
        Deletes a customer profile from the database.

        Args:
            customer_id: The ID of the customer to delete.
        """
        logger.info(f"Deleting customer profile for {customer_id}")
        # Mock deletion: nothing is removed; a plain-text (non-JSON) message is returned.
        return f"Customer profile for {customer_id}"


async def validation_hook_async(name: str, func: Callable, arguments: Dict[str, Any]):
    """Async counterpart of the validation hook: check IDs, run the call, sanitize output.

    Args:
        name: Name of the tool being invoked.
        func: Callable that performs the tool call (presumably the next hook
            in the chain, or the tool itself). It may be synchronous, or it
            may return an awaitable.
        arguments: Keyword arguments forwarded to ``func``.

    Returns:
        The (awaited, if needed) result of ``func``. For
        ``retrieve_customer_profile`` the JSON payload is re-serialized
        without its ``"name"`` field; other results are returned unchanged.

    Raises:
        ValueError: If either customer-profile tool is called with
            ``customer_id == "123"``.
    """
    if name == "retrieve_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot retrieve customer profile for ID 123")

    if name == "delete_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot delete customer profile for ID 123")

    logger.info("Before Validation Hook")
    result = func(**arguments)
    # Await anything awaitable instead of testing iscoroutinefunction(func):
    # a plain callable (lambda, callable object, wrapper) can still return a
    # coroutine, which would otherwise be passed on un-awaited.
    if isawaitable(result):
        result = await result
    logger.info("After Validation Hook")

    # Only the retrieve tool returns a JSON profile; other tools return plain text.
    if name == "retrieve_customer_profile":
        profile = json.loads(result)
        # Remove name from the profile to sanitize the output; tolerate its absence.
        profile.pop("name", None)
        return json.dumps(profile)
    return result


async def logger_hook_async(name: str, func: Callable, arguments: Dict[str, Any]):
    """Async counterpart of the logger hook: log before and after the wrapped call.

    Args:
        name: Tool name (not used here; part of the hook signature).
        func: Callable invoked with ``arguments``. It may be synchronous, or
            it may return an awaitable.
        arguments: Keyword arguments forwarded to ``func``.

    Returns:
        The result of ``func``, awaited first if it is awaitable.
    """
    logger.info("Before Logger Hook")
    result = func(**arguments)
    # Await anything awaitable instead of testing iscoroutinefunction(func):
    # a plain callable (lambda, callable object, wrapper) can still return a
    # coroutine, which would otherwise be returned un-awaited.
    if isawaitable(result):
        result = await result
    logger.info("After Logger Hook")
    return result


async_agent = Agent(
    # Hooks are executed in list order: validation first, then logging.
    tool_hooks=[validation_hook_async, logger_hook_async],
    tools=[CustomerDBToolsAsync()],
)


# ---------------------------------------------------------------------------
# Run Agent
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Synchronous agent: a single profile lookup.
    sync_agent.print_response("I am customer 456, please retrieve my profile.")

    # Async agent: each prompt gets its own event loop via asyncio.run, as before.
    for prompt in (
        "I am customer 456, please retrieve my profile.",
        "I am customer 456, please delete my profile.",
    ):
        asyncio.run(async_agent.aprint_response(prompt, stream=True))

Run the Example

# Clone and set up the repo
git clone https://github.com/agno-agi/agno.git
cd agno/cookbook/91_tools/tool_hooks

# Create and activate virtual environment
./scripts/demo_setup.sh
source .venvs/demo/bin/activate

python tool_hooks_in_toolkit_nested.py