| """ | |
| Transformation for Bedrock Invoke Agent | |
| https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html | |
| """ | |
| import base64 | |
| import json | |
| import uuid | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | |
| import httpx | |
| from litellm._logging import verbose_logger | |
| from litellm.litellm_core_utils.prompt_templates.common_utils import ( | |
| convert_content_list_to_str, | |
| ) | |
| from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException | |
| from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM | |
| from litellm.llms.bedrock.common_utils import BedrockError | |
| from litellm.types.llms.bedrock_invoke_agents import ( | |
| InvokeAgentChunkPayload, | |
| InvokeAgentEvent, | |
| InvokeAgentEventHeaders, | |
| InvokeAgentEventList, | |
| InvokeAgentTrace, | |
| InvokeAgentTracePayload, | |
| InvokeAgentUsage, | |
| ) | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import Choices, Message, ModelResponse | |

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class AmazonInvokeAgentConfig(BaseConfig, BaseAWSLLM):
    def __init__(self, **kwargs):
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        This is a base invoke agent model mapping. For Invoke Agent, define a
        bedrock provider-specific config that extends this class.

        Bedrock Invoke Agent has no OpenAI-compatible params.

        As of May 29, 2025, streaming is not supported.
        """
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        This is a base invoke agent model mapping. For Invoke Agent, define a
        bedrock provider-specific config that extends this class.
        """
        return optional_params

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the request.
        """
        ### SET RUNTIME ENDPOINT ###
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=self._get_aws_region_name(
                optional_params=optional_params, model=model
            ),
            endpoint_type="agent",
        )

        agent_id, agent_alias_id = self._get_agent_id_and_alias_id(model)
        session_id = self._get_session_id(optional_params)

        endpoint_url = f"{endpoint_url}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text"
        return endpoint_url
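
    # Illustrative example (hypothetical IDs): for model="agent/L1RT58GYRW/MFPSBCXYTW"
    # in us-west-2, get_complete_url above resolves to something like
    #   https://bedrock-agent-runtime.us-west-2.amazonaws.com/agents/L1RT58GYRW/agentAliases/MFPSBCXYTW/sessions/{session_id}/text
    # (the exact host is whatever get_runtime_endpoint returns for endpoint_type="agent").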

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
        )

    def _get_agent_id_and_alias_id(self, model: str) -> Tuple[str, str]:
        """
        model = "agent/L1RT58GYRW/MFPSBCXYTW"
        agent_id = "L1RT58GYRW"
        agent_alias_id = "MFPSBCXYTW"
        """
        # Split the model string by '/' and extract components
        parts = model.split("/")
        if len(parts) != 3 or parts[0] != "agent":
            raise ValueError(
                "Invalid model format. Expected format: 'model=agent/AGENT_ID/ALIAS_ID'"
            )
        return parts[1], parts[2]  # Return (agent_id, agent_alias_id)

    def _get_session_id(self, optional_params: dict) -> str:
        """Return the caller-provided `sessionID`, or generate a fresh UUID."""
        return optional_params.get("sessionID", None) or str(uuid.uuid4())

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # use the last message content as the query
        query: str = convert_content_list_to_str(messages[-1])

        return {
            "inputText": query,
            "enableTrace": True,
            **optional_params,
        }
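
    # Illustrative example: for messages=[{"role": "user", "content": "Hi"}],
    # transform_request above produces {"inputText": "Hi", "enableTrace": True},
    # plus any extra keys the caller passed via optional_params.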

    def _parse_aws_event_stream(self, raw_content: bytes) -> InvokeAgentEventList:
        """
        Parse AWS event stream format using boto3/botocore's built-in parser.
        This is the same approach used in the existing AWSEventStreamDecoder.
        """
        try:
            from botocore.eventstream import EventStreamBuffer
            from botocore.parsers import EventStreamJSONParser
        except ImportError:
            raise ImportError("boto3/botocore is required for AWS event stream parsing")

        events: InvokeAgentEventList = []
        parser = EventStreamJSONParser()
        event_stream_buffer = EventStreamBuffer()

        # Add the entire response to the buffer
        event_stream_buffer.add_data(raw_content)

        # Process all events in the buffer
        for event in event_stream_buffer:
            try:
                headers = self._extract_headers_from_event(event)
                event_type = headers.get("event_type", "")

                if event_type == "chunk":
                    # Handle chunk events specially - they contain decoded content, not JSON
                    message = self._parse_message_from_event(event, parser)
                    parsed_event: InvokeAgentEvent = InvokeAgentEvent()
                    if message:
                        # For chunk events, create a payload with the decoded content
                        parsed_event = {
                            "headers": headers,
                            "payload": {
                                "bytes": base64.b64encode(
                                    message.encode("utf-8")
                                ).decode("utf-8")
                            },  # Re-encode for consistency
                        }
                        events.append(parsed_event)
                elif event_type == "trace":
                    # Handle trace events normally - they contain JSON
                    message = self._parse_message_from_event(event, parser)
                    if message:
                        try:
                            event_data = json.loads(message)
                            parsed_event = {
                                "headers": headers,
                                "payload": event_data,
                            }
                            events.append(parsed_event)
                        except json.JSONDecodeError as e:
                            verbose_logger.warning(
                                f"Failed to parse trace event JSON: {e}"
                            )
                else:
                    verbose_logger.debug(f"Unknown event type: {event_type}")
            except Exception as e:
                verbose_logger.error(f"Error processing event: {e}")
                continue

        return events
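
    # Illustrative shape of the list returned above (values abbreviated/hypothetical):
    # [
    #     {"headers": {"event_type": "chunk", ...}, "payload": {"bytes": "SGVsbG8="}},
    #     {"headers": {"event_type": "trace", ...}, "payload": {"trace": {...}}},
    # ]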

    def _parse_message_from_event(self, event, parser) -> Optional[str]:
        """Extract message content from an AWS event, adapted from AWSEventStreamDecoder."""
        try:
            response_dict = event.to_response_dict()
            verbose_logger.debug(f"Response dict: {response_dict}")

            # Use the same response shape parsing as the existing decoder
            parsed_response = parser.parse(
                response_dict, self._get_response_stream_shape()
            )
            verbose_logger.debug(f"Parsed response: {parsed_response}")

            if response_dict["status_code"] != 200:
                # The body is a decoded string; it may carry a JSON error payload
                # with a "message" field
                decoded_body = response_dict["body"].decode()
                try:
                    body_json = json.loads(decoded_body)
                    error_message = body_json.get("message", decoded_body)
                except json.JSONDecodeError:
                    error_message = decoded_body
                exception_status = response_dict["headers"].get(":exception-type") or ""
                error_message = f"{exception_status} {error_message}".strip()
                raise BedrockError(
                    status_code=response_dict["status_code"],
                    message=error_message,
                )

            if "chunk" in parsed_response:
                chunk = parsed_response.get("chunk")
                if not chunk:
                    return None
                return chunk.get("bytes").decode()
            else:
                chunk = response_dict.get("body")
                if not chunk:
                    return None
                return chunk.decode()
        except Exception as e:
            verbose_logger.debug(f"Error parsing message from event: {e}")
            return None

    def _extract_headers_from_event(self, event) -> InvokeAgentEventHeaders:
        """Extract headers from an AWS event for categorization."""
        try:
            response_dict = event.to_response_dict()
            headers = response_dict.get("headers", {})

            # Extract the event-type and content-type headers that we care about
            return InvokeAgentEventHeaders(
                event_type=headers.get(":event-type", ""),
                content_type=headers.get(":content-type", ""),
                message_type=headers.get(":message-type", ""),
            )
        except Exception as e:
            verbose_logger.debug(f"Error extracting headers: {e}")
            return InvokeAgentEventHeaders(
                event_type="", content_type="", message_type=""
            )

    def _get_response_stream_shape(self):
        """Get the response stream shape for parsing, reusing existing logic."""
        try:
            # Try to reuse the cached shape from the existing decoder
            from litellm.llms.bedrock.chat.invoke_handler import (
                get_response_stream_shape,
            )

            return get_response_stream_shape()
        except ImportError:
            # Fallback: create our own shape
            try:
                from botocore.loaders import Loader
                from botocore.model import ServiceModel

                loader = Loader()
                bedrock_service_dict = loader.load_service_model(
                    "bedrock-runtime", "service-2"
                )
                bedrock_service_model = ServiceModel(bedrock_service_dict)
                return bedrock_service_model.shape_for("ResponseStream")
            except Exception as e:
                verbose_logger.warning(f"Could not load response stream shape: {e}")
                return None

    def _extract_response_content(self, events: InvokeAgentEventList) -> str:
        """Extract the final response content from parsed events."""
        response_parts = []

        for event in events:
            headers = event.get("headers", {})
            payload = event.get("payload")
            event_type = headers.get(
                "event_type"
            )  # Note: using event_type not event-type

            if event_type == "chunk" and payload:
                # Extract base64 encoded content from chunk events
                chunk_payload: InvokeAgentChunkPayload = payload  # type: ignore
                encoded_bytes = chunk_payload.get("bytes", "")
                if encoded_bytes:
                    try:
                        decoded_content = base64.b64decode(encoded_bytes).decode(
                            "utf-8"
                        )
                        response_parts.append(decoded_content)
                    except Exception as e:
                        verbose_logger.warning(f"Failed to decode chunk content: {e}")

        return "".join(response_parts)

    def _extract_usage_info(self, events: InvokeAgentEventList) -> InvokeAgentUsage:
        """Extract token usage information from trace events."""
        usage_info = InvokeAgentUsage(
            inputTokens=0,
            outputTokens=0,
            model=None,
        )
        response_model: Optional[str] = None

        for event in events:
            if not self._is_trace_event(event):
                continue

            trace_data = self._get_trace_data(event)
            if not trace_data:
                continue

            verbose_logger.debug(f"Trace event: {trace_data}")

            # Extract usage from pre-processing trace
            self._extract_and_update_preprocessing_usage(
                trace_data=trace_data,
                usage_info=usage_info,
            )

            # Extract model from orchestration trace
            if response_model is None:
                response_model = self._extract_orchestration_model(trace_data)

        usage_info["model"] = response_model
        return usage_info

    def _is_trace_event(self, event: InvokeAgentEvent) -> bool:
        """Check if the event is a trace event."""
        headers = event.get("headers", {})
        event_type = headers.get("event_type")
        payload = event.get("payload")
        return event_type == "trace" and payload is not None

    def _get_trace_data(self, event: InvokeAgentEvent) -> Optional[InvokeAgentTrace]:
        """Extract trace data from a trace event."""
        payload = event.get("payload")
        if not payload:
            return None

        trace_payload: InvokeAgentTracePayload = payload  # type: ignore
        return trace_payload.get("trace", {})

    def _extract_and_update_preprocessing_usage(
        self, trace_data: InvokeAgentTrace, usage_info: InvokeAgentUsage
    ) -> None:
        """Extract usage information from preprocessing trace."""
        pre_processing = trace_data.get("preProcessingTrace", {})
        if not pre_processing:
            return

        model_output = pre_processing.get("modelInvocationOutput", {})
        if not model_output:
            return

        metadata = model_output.get("metadata", {})
        if not metadata:
            return

        usage: Optional[Union[InvokeAgentUsage, Dict]] = metadata.get("usage", {})
        if not usage:
            return

        usage_info["inputTokens"] += usage.get("inputTokens", 0)
        usage_info["outputTokens"] += usage.get("outputTokens", 0)

    def _extract_orchestration_model(
        self, trace_data: InvokeAgentTrace
    ) -> Optional[str]:
        """Extract model information from orchestration trace."""
        orchestration_trace = trace_data.get("orchestrationTrace", {})
        if not orchestration_trace:
            return None

        model_invocation = orchestration_trace.get("modelInvocationInput", {})
        if not model_invocation:
            return None

        return model_invocation.get("foundationModel")

    def _build_model_response(
        self,
        content: str,
        model: str,
        usage_info: InvokeAgentUsage,
        model_response: ModelResponse,
    ) -> ModelResponse:
        """Build the final ModelResponse object."""
        # Create the message content
        message = Message(content=content, role="assistant")

        # Create choices
        choice = Choices(finish_reason="stop", index=0, message=message)

        # Update model response; fall back to the requested model when no model
        # was found in the traces (.get() with a default still returns None when
        # the key is present with a None value)
        model_response.choices = [choice]
        model_response.model = usage_info.get("model") or model

        # Add usage information if available
        if usage_info:
            from litellm.types.utils import Usage

            usage = Usage(
                prompt_tokens=usage_info.get("inputTokens", 0),
                completion_tokens=usage_info.get("outputTokens", 0),
                total_tokens=usage_info.get("inputTokens", 0)
                + usage_info.get("outputTokens", 0),
            )
            setattr(model_response, "usage", usage)

        return model_response

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            # Get the raw binary content
            raw_content = raw_response.content
            verbose_logger.debug(
                f"Processing {len(raw_content)} bytes of AWS event stream data"
            )

            # Parse the AWS event stream format
            events = self._parse_aws_event_stream(raw_content)
            verbose_logger.debug(f"Parsed {len(events)} events from stream")

            # Extract response content from chunk events
            content = self._extract_response_content(events)

            # Extract usage information from trace events
            usage_info = self._extract_usage_info(events)

            # Build and return the model response
            return self._build_model_response(
                content=content,
                model=model,
                usage_info=usage_info,
                model_response=model_response,
            )
        except Exception as e:
            verbose_logger.error(
                f"Error processing Bedrock Invoke Agent response: {str(e)}"
            )
            raise BedrockError(
                message=f"Error processing response: {str(e)}",
                status_code=raw_response.status_code,
            )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BedrockError(status_code=status_code, message=error_message)

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        return True
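

# Minimal usage sketch (illustrative; assumes this config is wired up for the
# "bedrock/agent/..." model prefix in litellm and that AWS credentials are
# configured in the environment — the agent/alias IDs below are hypothetical):
#
#     import litellm
#
#     response = litellm.completion(
#         model="bedrock/agent/L1RT58GYRW/MFPSBCXYTW",  # agent/AGENT_ID/ALIAS_ID
#         messages=[{"role": "user", "content": "What can you help me with?"}],
#     )
#     print(response.choices[0].message.content)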