# tool_hooks_in_toolkit_nested.py
"""Show how to use multiple tool execution hooks, to run logic before and after a tool is called."""
import asyncio
import json
from inspect import isawaitable, iscoroutinefunction
from typing import Any, Callable, Dict

from agno.agent import Agent
from agno.tools import Toolkit
from agno.utils.log import logger
# ---------------------------------------------------------------------------
# Create Agent
# ---------------------------------------------------------------------------
class CustomerDBTools(Toolkit):
    """Synchronous demo toolkit with two customer-database tools.

    Both tools return canned values; nothing is actually read from or
    written to a database.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Expose each method to the agent as a tool via Toolkit.register.
        self.register(self.retrieve_customer_profile)
        self.register(self.delete_customer_profile)

    def retrieve_customer_profile(self, customer_id: str):
        """
        Retrieves a customer profile from the database.

        Args:
            customer_id: The ID of the customer to retrieve.

        Returns:
            A JSON string containing the customer profile.
        """
        logger.info(f"Looking up customer profile for {customer_id}")
        return json.dumps(
            {
                "customer_id": customer_id,
                "name": "John Doe",
                "email": "john.doe@example.com",
            }
        )

    def delete_customer_profile(self, customer_id: str):
        """
        Deletes a customer profile from the database.

        Args:
            customer_id: The ID of the customer to delete.

        Returns:
            A plain-text confirmation message (not JSON).
        """
        logger.info(f"Deleting customer profile for {customer_id}")
        return f"Customer profile for {customer_id} deleted"
def validation_hook(name: str, func: Callable, arguments: Dict[str, Any]):
    """Reject calls for customer ``"123"`` and strip ``name`` from profile lookups.

    Args:
        name: Name of the tool being invoked.
        func: Callable to invoke with ``arguments`` (presumably the tool itself
            or the next hook in the chain).
        arguments: Keyword arguments for ``func``.

    Returns:
        For ``retrieve_customer_profile``, the JSON result re-serialized
        without its ``name`` field; for any other tool, ``func``'s result
        unchanged.

    Raises:
        ValueError: If either customer tool is called with ``customer_id == "123"``.
    """
    if name == "retrieve_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot retrieve customer profile for ID 123")
    if name == "delete_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot delete customer profile for ID 123")
    logger.info("Before Validation Hook")
    result = func(**arguments)
    logger.info("After Validation Hook")
    # Only the profile lookup returns JSON. Other tools (e.g. delete) return
    # plain text that json.loads would reject, so sanitize just this tool.
    if name == "retrieve_customer_profile":
        profile = json.loads(result)
        # Remove name from result to sanitize the output; tolerate its absence.
        profile.pop("name", None)
        return json.dumps(profile)
    return result
def logger_hook(name: str, func: Callable, arguments: Dict[str, Any]):
    """Log before and after the wrapped call and pass its result through.

    Args:
        name: Name of the tool being invoked (not used by this hook).
        func: Callable to invoke with ``arguments`` (presumably the tool itself
            or the next hook in the chain).
        arguments: Keyword arguments forwarded to ``func``.

    Returns:
        Exactly what ``func`` returned.
    """
    logger.info("Before Logger Hook")
    outcome = func(**arguments)
    logger.info("After Logger Hook")
    return outcome
sync_agent = Agent(
    # Hooks are executed in order of the list: validation_hook first,
    # then logger_hook.
    tool_hooks=[validation_hook, logger_hook],
    tools=[CustomerDBTools()],
)
# ---------------------------------------------------------------------------
# Async Variant
# ---------------------------------------------------------------------------
class CustomerDBToolsAsync(Toolkit):
    """Demo toolkit mixing an async tool and a sync tool.

    ``retrieve_customer_profile`` is a coroutine function, while
    ``delete_customer_profile`` stays synchronous; the async hooks below
    handle both kinds. Both tools return canned values only.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Expose each method to the agent as a tool via Toolkit.register.
        self.register(self.retrieve_customer_profile)
        self.register(self.delete_customer_profile)

    async def retrieve_customer_profile(self, customer_id: str):
        """
        Retrieves a customer profile from the database.

        Args:
            customer_id: The ID of the customer to retrieve.

        Returns:
            A JSON string containing the customer profile.
        """
        logger.info(f"Looking up customer profile for {customer_id}")
        return json.dumps(
            {
                "customer_id": customer_id,
                "name": "John Doe",
                "email": "john.doe@example.com",
            }
        )

    def delete_customer_profile(self, customer_id: str):
        """
        Deletes a customer profile from the database.

        Args:
            customer_id: The ID of the customer to delete.

        Returns:
            A plain-text confirmation message (not JSON).
        """
        logger.info(f"Deleting customer profile for {customer_id}")
        return f"Customer profile for {customer_id} deleted"
async def validation_hook_async(name: str, func: Callable, arguments: Dict[str, Any]):
    """Async variant of the validation hook: block ID ``"123"``, sanitize lookups.

    Args:
        name: Name of the tool being invoked.
        func: Sync or async callable to invoke with ``arguments`` (presumably
            the tool itself or the next hook in the chain).
        arguments: Keyword arguments for ``func``.

    Returns:
        For ``retrieve_customer_profile``, the JSON result re-serialized
        without its ``name`` field; for any other tool, ``func``'s result
        unchanged.

    Raises:
        ValueError: If either customer tool is called with ``customer_id == "123"``.
    """
    if name == "retrieve_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot retrieve customer profile for ID 123")
    if name == "delete_customer_profile":
        cust_id = arguments.get("customer_id")
        if cust_id == "123":
            raise ValueError("Cannot delete customer profile for ID 123")
    logger.info("Before Validation Hook")
    result = func(**arguments)
    # Await whatever the call produced if it is awaitable. This covers
    # coroutine functions and also sync wrappers that return a coroutine,
    # which an iscoroutinefunction(func) check would miss.
    if isawaitable(result):
        result = await result
    logger.info("After Validation Hook")
    # Remove name from result to sanitize the output. Only the profile lookup
    # returns JSON, so other tools' results pass through untouched.
    if name == "retrieve_customer_profile":
        profile = json.loads(result)
        profile.pop("name", None)
        return json.dumps(profile)
    return result
async def logger_hook_async(name: str, func: Callable, arguments: Dict[str, Any]):
    """Log before and after the wrapped call, awaiting it when needed.

    Args:
        name: Name of the tool being invoked (not used by this hook).
        func: Sync or async callable to invoke with ``arguments`` (presumably
            the tool itself or the next hook in the chain).
        arguments: Keyword arguments forwarded to ``func``.

    Returns:
        ``func``'s result, awaited first if the call produced an awaitable.
    """
    logger.info("Before Logger Hook")
    result = func(**arguments)
    # Checking the result (not the function) also handles sync wrappers
    # that return a coroutine.
    if isawaitable(result):
        result = await result
    logger.info("After Logger Hook")
    return result
async_agent = Agent(
    # Hooks are executed in order of the list: validation_hook_async first,
    # then logger_hook_async.
    tool_hooks=[validation_hook_async, logger_hook_async],
    tools=[CustomerDBToolsAsync()],
)
# ---------------------------------------------------------------------------
# Run Agent
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Synchronous agent: a single profile lookup.
    sync_agent.print_response("I am customer 456, please retrieve my profile.")

    # Async agent: each prompt gets its own event loop via asyncio.run,
    # exactly one run per prompt, in this order.
    for async_prompt in (
        "I am customer 456, please retrieve my profile.",
        "I am customer 456, please delete my profile.",
    ):
        asyncio.run(async_agent.aprint_response(async_prompt, stream=True))
Run the Example
# Clone and setup repo
git clone https://github.com/agno-agi/agno.git
cd agno
# Create and activate virtual environment (run from the repo root, where
# ./scripts and .venvs are resolved relative to the current directory)
./scripts/demo_setup.sh
source .venvs/demo/bin/activate
# Run the example
python cookbook/91_tools/tool_hooks/tool_hooks_in_toolkit_nested.py