Spaces:
Paused
Paused
| import React, { useState, useEffect } from 'react'; | |
| import { Card, Form, Typography, Select, Input, Switch, Tooltip, Modal, message, Divider, Space, Tag, Image, Steps } from 'antd'; | |
| import { Button, TextInput } from '@tremor/react'; | |
| import type { FormInstance } from 'antd'; | |
| import { GuardrailProviders, guardrail_provider_map, shouldRenderPIIConfigSettings, guardrailLogoMap } from './guardrail_info_helpers'; | |
| import { createGuardrailCall, getGuardrailUISettings, getGuardrailProviderSpecificParams } from '../networking'; | |
| import PiiConfiguration from './pii_configuration'; | |
| import GuardrailProviderFields from './guardrail_provider_fields'; | |
// Short aliases for the antd sub-components used in the JSX below.
// Only the aliases the component actually renders are pulled out.
const { Option } = Select;
const { Step } = Steps;
/**
 * Human-friendly explanation for each guardrail mode, keyed by the mode value.
 * Shown under the mode name in the "Mode" dropdown; key order is the order
 * the modes are listed in.
 */
const modeDescriptions = {
  pre_call: 'Before LLM Call - Runs before the LLM call and checks the input (Recommended)',
  during_call: 'During LLM Call - Runs in parallel with the LLM call, with response held until check completes',
  post_call: 'After LLM Call - Runs after the LLM call and checks only the output',
  logging_only: 'Logging Only - Only runs on logging callbacks without affecting the LLM call',
} as const;
/** Props for the AddGuardrailForm modal. */
interface AddGuardrailFormProps {
  /** Whether the modal is open. */
  visible: boolean;
  /** Called when the modal is dismissed, and after a guardrail is created. */
  onClose: () => void;
  /** Token passed to the networking helpers; nothing is fetched while it is null. */
  accessToken: string | null;
  /** Called after the guardrail has been created successfully. */
  onSuccess: () => void;
}

/** UI settings returned by `getGuardrailUISettings`. */
interface GuardrailSettings {
  /** PII entity names offered in the PII configuration step. */
  supported_entities: string[];
  /** Actions that can be assigned to a selected PII entity. */
  supported_actions: string[];
  /** Mode values offered in the "Mode" dropdown. */
  supported_modes: string[];
  /** Grouping of entities into categories for the PII configuration step. */
  pii_entity_categories: Array<{
    category: string;
    entities: string[];
  }>;
}

/**
 * Shape of the `litellm_params` object sent when creating a guardrail.
 *
 * `mode` may be a single mode or a list: the "Mode" select is multi-select,
 * but its initial value is the single string "pre_call".
 * NOTE(review): the component currently declares an equivalent inline type
 * instead of referencing this interface.
 */
interface LiteLLMParams {
  guardrail: string;
  mode: string | string[];
  default_on: boolean;
  [key: string]: any; // Allow additional properties for specific guardrails
}

/** Descriptor for one provider-specific parameter. */
interface ProviderParam {
  /** Param name; also used as the form field name and the `litellm_params` key. */
  param: string;
  description: string;
  required: boolean;
  /** Presumably the backend-side default, as a string — TODO confirm. */
  default_value?: string;
  /** Presumably the allowed values for enum-like params — TODO confirm. */
  options?: string[];
  type?: string;
}

/**
 * Mapping of provider -> list of param descriptors, as returned by
 * `getGuardrailProviderSpecificParams`. The component looks entries up by the
 * lowercased value from `guardrail_provider_map`.
 */
interface ProviderParamsResponse {
  [provider: string]: ProviderParam[];
}
/**
 * Modal wizard for creating a guardrail.
 *
 * Step 0 ("Basic Info") collects the guardrail name, provider, mode(s), the
 * "Always On" flag, and whatever provider-specific inputs
 * `GuardrailProviderFields` renders for the chosen provider. Step 1 shows the
 * PII entity/action picker when `shouldRenderPIIConfigSettings` is true for
 * the selected provider, and holds the "Create Guardrail" button.
 */
const AddGuardrailForm: React.FC<AddGuardrailFormProps> = ({
  visible,
  onClose,
  accessToken,
  onSuccess,
}) => {
  const [form] = Form.useForm();
  const [loading, setLoading] = useState(false);
  const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
  const [guardrailSettings, setGuardrailSettings] = useState<GuardrailSettings | null>(null);
  const [selectedEntities, setSelectedEntities] = useState<string[]>([]);
  const [selectedActions, setSelectedActions] = useState<{ [key: string]: string }>({});
  const [currentStep, setCurrentStep] = useState(0);
  const [providerParams, setProviderParams] = useState<ProviderParamsResponse | null>(null);

  // Fetch guardrail UI settings + provider params on mount / accessToken change.
  useEffect(() => {
    if (!accessToken) return;

    // Flipped by the cleanup below so that a response arriving after unmount,
    // or after accessToken changed, does not overwrite newer state.
    let cancelled = false;

    const fetchData = async () => {
      try {
        // The two requests are independent, so run them in parallel.
        const [uiSettings, providerParamsResp] = await Promise.all([
          getGuardrailUISettings(accessToken),
          getGuardrailProviderSpecificParams(accessToken),
        ]);
        if (cancelled) return;
        setGuardrailSettings(uiSettings);
        setProviderParams(providerParamsResp);
      } catch (error) {
        if (cancelled) return;
        console.error('Error fetching guardrail data:', error);
        message.error('Failed to load guardrail configuration');
      }
    };

    // fetchData handles its own errors, so the promise can be left floating.
    void fetchData();

    return () => {
      cancelled = true;
    };
  }, [accessToken]);

  /** Records the new provider and clears inputs that belonged to the previous one. */
  const handleProviderChange = (value: string) => {
    setSelectedProvider(value);
    form.setFieldsValue({
      config: undefined,
      presidio_analyzer_api_base: undefined,
      presidio_anonymizer_api_base: undefined,
    });
    // PII selections only make sense for the provider they were made for.
    setSelectedEntities([]);
    setSelectedActions({});
  };

  /** Toggles `entity` in the list of selected PII entities. */
  const handleEntitySelect = (entity: string) => {
    setSelectedEntities((prev) =>
      prev.includes(entity) ? prev.filter((e) => e !== entity) : [...prev, entity]
    );
  };

  /** Sets the action to apply to `entity`. */
  const handleActionSelect = (entity: string, action: string) => {
    setSelectedActions((prev) => ({
      ...prev,
      [entity]: action,
    }));
  };

  /**
   * Validates the current step and advances to the next one.
   *
   * On step 0 every mounted field is validated: the basic-info fields plus
   * whatever provider-specific fields `GuardrailProviderFields` rendered.
   * This has to happen here, because those fields unmount when step 1 is
   * shown, and `validateFields()` in `handleSubmit` only sees mounted fields.
   */
  const nextStep = async () => {
    try {
      if (currentStep === 0) {
        await form.validateFields();
      }
      setCurrentStep((step) => step + 1);
    } catch (error) {
      // antd has already marked the invalid fields; stay on this step.
      console.error("Form validation failed:", error);
    }
  };

  const prevStep = () => {
    setCurrentStep((step) => step - 1);
  };

  /** Returns the form and all wizard state to their initial values. */
  const resetForm = () => {
    form.resetFields();
    setSelectedProvider(null);
    setSelectedEntities([]);
    setSelectedActions({});
    setCurrentStep(0);
  };

  const handleClose = () => {
    resetForm();
    onClose();
  };

  /**
   * Validates the visible step, builds the create-guardrail payload from all
   * form values (including step-0 fields that are no longer mounted), and
   * posts it with `createGuardrailCall`.
   *
   * On success the form is reset and `onSuccess` then `onClose` are called.
   * On failure an error toast is shown and the modal stays open.
   */
  const handleSubmit = async () => {
    try {
      await form.validateFields();
    } catch (validationError) {
      // antd rejects with an error-info object, not an Error, and already
      // highlights the offending fields inline. Stringifying it into the
      // "Failed to create guardrail" toast would only show "[object Object]".
      console.error("Form validation failed:", validationError);
      return;
    }

    setLoading(true);
    try {
      if (!accessToken) {
        throw new Error("No access token available");
      }

      // `true` returns every stored value, including fields from step 0,
      // which is no longer mounted.
      const values = form.getFieldsValue(true);

      const guardrailData: {
        guardrail_name: string;
        litellm_params: {
          guardrail: string;
          // The mode select is multi-select, but its initial value is a string.
          mode: string | string[];
          default_on: boolean;
          [key: string]: any; // Allow dynamic provider-specific properties
        };
        guardrail_info: any;
      } = {
        guardrail_name: values.guardrail_name,
        litellm_params: {
          guardrail: guardrail_provider_map[values.provider],
          mode: values.mode,
          default_on: values.default_on,
        },
        guardrail_info: {},
      };

      if (values.provider === 'PresidioPII') {
        if (selectedEntities.length > 0) {
          const piiEntitiesConfig: { [key: string]: string } = {};
          selectedEntities.forEach((entity) => {
            // Default to MASK if no action was selected for this entity.
            piiEntitiesConfig[entity] = selectedActions[entity] || 'MASK';
          });
          guardrailData.litellm_params.pii_entities_config = piiEntitiesConfig;
        }
        // The API bases are needed whether or not any entity was selected.
        if (values.presidio_analyzer_api_base) {
          guardrailData.litellm_params.presidio_analyzer_api_base = values.presidio_analyzer_api_base;
        }
        if (values.presidio_anonymizer_api_base) {
          guardrailData.litellm_params.presidio_anonymizer_api_base = values.presidio_anonymizer_api_base;
        }
      } else if (values.config) {
        // Free-form JSON config goes into guardrail_info.
        try {
          guardrailData.guardrail_info = JSON.parse(values.config);
        } catch {
          message.error('Invalid JSON in configuration');
          return; // the `finally` below clears the loading flag
        }
      }

      // The backend lists the extra params each provider accepts
      // (`/guardrails/ui/provider_specific_params`, pre-fetched into
      // `providerParams`). Copy only those recognised params, and only
      // when the user supplied a value, rather than every unknown form field.
      if (providerParams && selectedProvider) {
        const providerKey = guardrail_provider_map[selectedProvider]?.toLowerCase();
        const providerSpecificParams = providerKey ? providerParams[providerKey] ?? [] : [];
        const allowedParams = new Set<string>(providerSpecificParams.map((p) => p.param));
        allowedParams.forEach((paramName) => {
          const paramValue = values[paramName];
          if (paramValue !== undefined && paramValue !== null && paramValue !== '') {
            guardrailData.litellm_params[paramName] = paramValue;
          }
        });
      }

      // The payload is intentionally not logged: it carries arbitrary
      // user-entered provider params, which may include secrets.
      await createGuardrailCall(accessToken, guardrailData);
      message.success('Guardrail created successfully');

      resetForm();
      onSuccess();
      onClose();
    } catch (error) {
      console.error("Failed to create guardrail:", error);
      message.error('Failed to create guardrail: ' + (error instanceof Error ? error.message : String(error)));
    } finally {
      setLoading(false);
    }
  };

  /** Provider logo (when one is mapped) followed by the provider's display name. */
  const renderProviderLabel = (displayName: string, logoSrc: string | undefined) => (
    <div style={{ display: 'flex', alignItems: 'center' }}>
      {logoSrc && (
        <img
          src={logoSrc}
          alt=""
          style={{
            height: '20px',
            width: '20px',
            marginRight: '8px',
            objectFit: 'contain',
          }}
          onError={(e) => {
            // Hide broken image icon if image fails to load
            e.currentTarget.style.display = 'none';
          }}
        />
      )}
      <span>{displayName}</span>
    </div>
  );

  /** Step 0: name, provider, mode(s), "Always On", and provider-specific fields. */
  const renderBasicInfo = () => {
    // Prefer the modes advertised by the backend. Fall back to the locally
    // described ones while settings are loading, or if the list is empty.
    const availableModes = guardrailSettings?.supported_modes?.length
      ? guardrailSettings.supported_modes
      : Object.keys(modeDescriptions);

    return (
      <>
        <Form.Item
          name="guardrail_name"
          label="Guardrail Name"
          rules={[{ required: true, message: 'Please enter a guardrail name' }]}
        >
          <TextInput placeholder="Enter a name for this guardrail" />
        </Form.Item>

        <Form.Item
          name="provider"
          label="Guardrail Provider"
          rules={[{ required: true, message: 'Please select a provider' }]}
        >
          <Select
            placeholder="Select a guardrail provider"
            onChange={handleProviderChange}
            optionLabelProp="label"
            showSearch
          >
            {Object.entries(GuardrailProviders).map(([key, value]) => (
              <Option
                key={key}
                value={key}
                label={renderProviderLabel(value, guardrailLogoMap[value])}
              >
                {renderProviderLabel(value, guardrailLogoMap[value])}
              </Option>
            ))}
          </Select>
        </Form.Item>

        <Form.Item
          name="mode"
          label="Mode"
          tooltip="How the guardrail should be applied"
          rules={[{ required: true, message: 'Please select a mode' }]}
        >
          <Select optionLabelProp="label" mode="multiple">
            {availableModes.map((mode) => (
              <Option key={mode} value={mode} label={mode}>
                <div>
                  <div>
                    <strong>{mode}</strong>
                    {mode === 'pre_call' && (
                      <Tag color="green" style={{ marginLeft: '8px' }}>Recommended</Tag>
                    )}
                  </div>
                  <div style={{ fontSize: '12px', color: '#888' }}>
                    {modeDescriptions[mode as keyof typeof modeDescriptions]}
                  </div>
                </div>
              </Option>
            ))}
          </Select>
        </Form.Item>

        <Form.Item
          name="default_on"
          label="Always On"
          tooltip="If enabled, this guardrail will be applied to all requests by default."
        >
          <Select>
            <Select.Option value={true}>Yes</Select.Option>
            <Select.Option value={false}>No</Select.Option>
          </Select>
        </Form.Item>

        {/* Provider-specific inputs, driven by the pre-fetched provider params */}
        <GuardrailProviderFields
          selectedProvider={selectedProvider}
          accessToken={accessToken}
          providerParams={providerParams}
        />
      </>
    );
  };

  /** Step 1 for Presidio: PII entity/action picker (null until settings load). */
  const renderPiiConfiguration = () => {
    if (!guardrailSettings || selectedProvider !== 'PresidioPII') return null;

    return (
      <PiiConfiguration
        entities={guardrailSettings.supported_entities}
        actions={guardrailSettings.supported_actions}
        selectedEntities={selectedEntities}
        selectedActions={selectedActions}
        onEntitySelect={handleEntitySelect}
        onActionSelect={handleActionSelect}
        entityCategories={guardrailSettings.pii_entity_categories}
      />
    );
  };

  const renderStepContent = () => {
    switch (currentStep) {
      case 0:
        return renderBasicInfo();
      case 1:
        // Providers without PII settings have nothing extra to configure here.
        return shouldRenderPIIConfigSettings(selectedProvider) ? renderPiiConfiguration() : null;
      default:
        return null;
    }
  };

  /** Previous / Next / Create / Cancel buttons for the current step. */
  const renderStepButtons = () => (
    <div className="flex justify-end space-x-2 mt-4">
      {currentStep > 0 && (
        <Button variant="secondary" onClick={prevStep}>
          Previous
        </Button>
      )}
      {currentStep < 1 && <Button onClick={nextStep}>Next</Button>}
      {currentStep === 1 && (
        <Button onClick={handleSubmit} loading={loading}>
          Create Guardrail
        </Button>
      )}
      <Button variant="secondary" onClick={handleClose}>
        Cancel
      </Button>
    </div>
  );

  return (
    <Modal
      title="Add Guardrail"
      open={visible}
      onCancel={handleClose}
      footer={null}
      width={700}
    >
      <Form
        form={form}
        layout="vertical"
        initialValues={{
          mode: "pre_call",
          default_on: false,
        }}
      >
        <Steps current={currentStep} className="mb-6">
          <Step title="Basic Info" />
          <Step title={selectedProvider === 'PresidioPII' ? "PII Configuration" : "Provider Configuration"} />
        </Steps>

        {renderStepContent()}
        {renderStepButtons()}
      </Form>
    </Modal>
  );
};

export default AddGuardrailForm;