|
|
""" |
|
|
Test suite for bird classifier agents. |
|
|
""" |
|
|
import asyncio |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Make the project root (the directory above this file's directory) importable
# so the project-local import that follows can be found when this file is run
# directly as a script.
#
# The path is resolved first: if ``__file__`` is a bare relative name such as
# "test_agent.py", ``.parent.parent`` would collapse to "." instead of the
# real grandparent directory, and the wrong entry would land on ``sys.path``.
parent_dir = Path(__file__).resolve().parent.parent
if str(parent_dir) not in sys.path:
    # Insert at the front so the project copy wins over any installed package
    # with the same module name.
    sys.path.insert(0, str(parent_dir))
|
|
|
|
|
from langgraph_agent import AgentFactory |
|
|
|
|
|
|
|
|
async def test_classifier_agent():
    """Exercise the basic classifier agent on a fixed list of image URLs.

    Builds the agent with ``AgentFactory.create_classifier_agent`` and sends
    one user message per URL asking it to classify the bird in the image.
    The content of the last message in each result is printed; nothing is
    asserted or returned.
    """
    divider = "=" * 70

    print("\n" + divider)
    print("Test Suite: Basic Classifier Agent")
    print(divider + "\n")

    classifier = await AgentFactory.create_classifier_agent()

    image_urls = [
        "https://images.unsplash.com/photo-1555169062-013468b47731?w=400",
        "https://images.unsplash.com/photo-1445820200644-69f87d946277?w=400",
    ]
    total = len(image_urls)

    for index, image_url in enumerate(image_urls, start=1):
        print(f"\n[TEST {index}/{total}]")
        print(divider)

        user_message = {
            "role": "user",
            "content": f"Classify the bird in this image: {image_url}",
        }
        response = await classifier.ainvoke({"messages": [user_message]})

        # The reply of interest is the final message in the returned history.
        final_message = response["messages"][-1]
        print(f"\n[RESULT]: {final_message.content}\n")

    print("\n[ALL TESTS COMPLETE!]\n")
|
|
|
|
|
|
|
|
async def test_multi_server_agent():
    """Exercise the multi-server agent (classifier + eBird) over two turns.

    Builds the agent with ``AgentFactory.create_multi_server_agent`` using
    ``with_memory=True`` and sends two user messages with the same
    ``thread_id`` config: first an image URL to identify, then a follow-up
    that refers to "this bird" without naming it. The content of the last
    message of each result is printed; nothing is asserted or returned.
    """
    divider = "=" * 70

    print("\n" + divider)
    print("Test Suite: Multi-Server Agent")
    print(divider + "\n")

    multi_agent = await AgentFactory.create_multi_server_agent(with_memory=True)
    # Both turns reuse this config; presumably the shared thread_id is what
    # lets the follow-up see the first turn (confirm against AgentFactory).
    session_config = {"configurable": {"thread_id": "test_session"}}

    async def run_turn(heading, prompt):
        """Print the heading, send one user prompt, print the final reply."""
        print(heading)
        print(divider)
        payload = {"messages": [{"role": "user", "content": prompt}]}
        reply = await multi_agent.ainvoke(payload, session_config)
        print(f"\n[RESULT]: {reply['messages'][-1].content}\n")

    await run_turn(
        "\n[TEST 1]: Classify bird from URL",
        "What bird is this? https://images.unsplash.com/photo-1555169062-013468b47731?w=400",
    )

    await run_turn(
        "\n[TEST 2]: Follow-up question (tests memory)",
        "Where can I see this bird near Boston (42.36, -71.06)?",
    )

    print("\n[ALL TESTS COMPLETE!]\n")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: `python <this file> multi` runs the multi-server
    # suite; any other invocation (including no argument) runs the basic
    # classifier suite. `sys` is already imported at the top of the module,
    # so it is not re-imported here.
    run_multi = len(sys.argv) > 1 and sys.argv[1] == "multi"
    selected_suite = test_multi_server_agent if run_multi else test_classifier_agent
    asyncio.run(selected_suite())