# Tests verifying that applied guardrails (Presidio PII masking) are recorded in
# the standard logging payload and propagated to Langfuse trace spans.
import asyncio
import io
import json
import os
import sys
import time
from typing import Optional
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Make the repository root importable before pulling in litellm.
sys.path.insert(0, os.path.abspath("../.."))

import litellm
from litellm import mock_completion
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
    PresidioPerRequestConfig,
    _OPTIONAL_PresidioPIIMasking,
)
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import (
    StandardLoggingGuardrailInformation,
    StandardLoggingPayload,
)
class TestCustomLogger(CustomLogger):
    """Callback that captures the standard logging payload of a successful call.

    Tests read ``standard_logging_payload`` after the (async) logging hooks
    have fired to assert on guardrail information.
    """

    def __init__(self, *args, **kwargs):
        # Forward to the base logger so any CustomLogger setup still runs.
        super().__init__(*args, **kwargs)
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Stash the payload for later assertions; kwargs is the litellm
        # logging kwargs dict, which carries "standard_logging_object".
        self.standard_logging_payload = kwargs.get("standard_logging_object")
@pytest.mark.asyncio
async def test_standard_logging_payload_includes_guardrail_information():
    """
    Test that the standard logging payload includes the guardrail information when a guardrail is applied
    """
    test_custom_logger = TestCustomLogger()
    litellm.callbacks = [test_custom_logger]
    presidio_guard = _OPTIONAL_PresidioPIIMasking(
        guardrail_name="presidio_guard",
        event_hook=GuardrailEventHooks.pre_call,
        presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
        presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
    )

    # 1. call the pre call hook with guardrail
    request_data = {
        "model": "gpt-4o",
        "messages": [
            {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
        ],
        "mock_response": "Hello",
        "guardrails": ["presidio_guard"],
        "metadata": {},
    }
    await presidio_guard.async_pre_call_hook(
        user_api_key_dict={},
        cache=None,
        data=request_data,
        call_type="acompletion",
    )

    # 2. call litellm.acompletion (mocked response; guardrail info is attached
    # to the logging payload by the pre-call hook above)
    await litellm.acompletion(**request_data)

    # 3. assert that the standard logging payload includes the guardrail information
    # (logging callbacks run asynchronously, so give them a moment to fire)
    await asyncio.sleep(1)
    print(
        "got standard logging payload=",
        json.dumps(test_custom_logger.standard_logging_payload, indent=4, default=str),
    )
    assert test_custom_logger.standard_logging_payload is not None
    guardrail_info = test_custom_logger.standard_logging_payload["guardrail_information"]
    assert guardrail_info is not None
    assert guardrail_info["guardrail_name"] == "presidio_guard"
    assert guardrail_info["guardrail_mode"] == GuardrailEventHooks.pre_call

    # assert that the guardrail_response is a response from presidio analyze
    presidio_response = guardrail_info["guardrail_response"]
    assert isinstance(presidio_response, list)
    for response_item in presidio_response:
        assert "analysis_explanation" in response_item
        assert "start" in response_item
        assert "end" in response_item
        assert "score" in response_item
        assert "entity_type" in response_item
        assert "recognition_metadata" in response_item

    # assert that the duration is not None
    assert guardrail_info["duration"] is not None
    assert guardrail_info["duration"] > 0

    # assert that we get the count of masked entities
    assert guardrail_info["masked_entity_count"] is not None
    assert guardrail_info["masked_entity_count"]["PHONE_NUMBER"] == 1
@pytest.mark.asyncio
async def test_langfuse_trace_includes_guardrail_information():
    """
    Test that the langfuse trace includes the guardrail information when a guardrail is applied
    """
    import httpx
    from unittest.mock import AsyncMock, patch

    from litellm.integrations.langfuse.langfuse_prompt_management import (
        LangfusePromptManagement,
    )

    callback = LangfusePromptManagement(flush_interval=3)
    import json

    # Create a mock Response object so no real Langfuse request is sent
    mock_response = AsyncMock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = {"status": "success"}

    # Create mock for httpx.Client.post
    mock_post = AsyncMock()
    mock_post.return_value = mock_response

    with patch("httpx.Client.post", mock_post):
        litellm._turn_on_debug()
        litellm.callbacks = [callback]
        presidio_guard = _OPTIONAL_PresidioPIIMasking(
            guardrail_name="presidio_guard",
            event_hook=GuardrailEventHooks.pre_call,
            presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"),
            presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"),
        )

        # 1. call the pre call hook with guardrail
        request_data = {
            "model": "gpt-4o",
            "messages": [
                {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"},
            ],
            "mock_response": "Hello",
            "guardrails": ["presidio_guard"],
            "metadata": {},
        }
        await presidio_guard.async_pre_call_hook(
            user_api_key_dict={},
            cache=None,
            data=request_data,
            call_type="acompletion",
        )

        # 2. call litellm.acompletion
        response = await litellm.acompletion(**request_data)

        # 3. Wait for async logging operations to complete (the Langfuse
        # batch flush interval is 3s, so wait a bit longer)
        await asyncio.sleep(5)

        # 4. Verify the Langfuse payload
        assert mock_post.call_count >= 1
        url = mock_post.call_args[0][0]
        request_body = mock_post.call_args[1].get("content")

        # Parse the JSON body
        actual_payload = json.loads(request_body)
        print("\nLangfuse payload:", json.dumps(actual_payload, indent=2))

        # Look for the guardrail span in the payload
        guardrail_span = None
        for item in actual_payload["batch"]:
            if (
                item["type"] == "span-create"
                and item["body"].get("name") == "guardrail"
            ):
                guardrail_span = item
                break

        # Assert that the guardrail span exists
        assert guardrail_span is not None, "No guardrail span found in Langfuse payload"

        # Validate the structure of the guardrail span
        assert guardrail_span["body"]["name"] == "guardrail"
        assert "metadata" in guardrail_span["body"]
        metadata = guardrail_span["body"]["metadata"]
        assert metadata["guardrail_name"] == "presidio_guard"
        assert metadata["guardrail_mode"] == GuardrailEventHooks.pre_call
        assert "guardrail_masked_entity_count" in metadata
        assert metadata["guardrail_masked_entity_count"]["PHONE_NUMBER"] == 1

        # Validate the output format matches the expected structure
        assert "output" in guardrail_span["body"]
        assert isinstance(guardrail_span["body"]["output"], list)
        assert len(guardrail_span["body"]["output"]) > 0

        # Validate the first output item has the expected structure
        output_item = guardrail_span["body"]["output"][0]
        assert "entity_type" in output_item
        assert output_item["entity_type"] == "PHONE_NUMBER"
        assert "score" in output_item
        assert "start" in output_item
        assert "end" in output_item
        assert "recognition_metadata" in output_item
        assert "recognizer_name" in output_item["recognition_metadata"]