Spaces:
Paused
Paused
| import ast | |
| from typing import List, Dict, Set, Optional | |
| import os | |
| from dataclasses import dataclass | |
| import argparse | |
| import re | |
| import sys | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import litellm | |
@dataclass
class FunctionInfo:
    """Store function information.

    Instances are created with keyword arguments (see
    ``FastAPIDocVisitor.visit_FunctionDef``), so the class is a dataclass in
    order to get a generated ``__init__``.
    """

    # Name of the function as written in its definition.
    name: str
    # Docstring text, or None when the function has no docstring.
    docstring: Optional[str]
    # (parameter name, annotation type name) pairs for annotated parameters.
    parameters: Set[tuple[str, str]]
    # Path of the file the function was found in.
    file_path: str
    # Line number of the function definition inside that file.
    line_number: int
class FastAPIDocVisitor(ast.NodeVisitor):
    """AST visitor to find FastAPI endpoint functions.

    Records a ``FunctionInfo`` for every visited function definition (sync or
    async) whose name is in ``target_functions``.
    """

    def __init__(self, target_functions: Set[str]):
        """Create a visitor that looks for the given function names."""
        self.target_functions = target_functions
        # Maps function name -> collected info. A later definition with the
        # same name overwrites an earlier one.
        self.functions: Dict[str, FunctionInfo] = {}
        # Expected to be set by the caller before visiting, so that results
        # carry the path of the file being scanned.
        self.current_file = ""

    @staticmethod
    def _annotation_type_name(annotation: Optional[ast.expr]) -> Optional[str]:
        """Return the type name of an annotation, or None if unsupported.

        A plain name ``Foo`` yields ``"Foo"``. A subscript ``Foo[Bar]`` yields
        the outer name ``"Foo"``. Every other annotation form (including a
        missing annotation) yields None.
        """
        if isinstance(annotation, ast.Subscript):
            annotation = annotation.value
        if isinstance(annotation, ast.Name):
            return annotation.id
        return None

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Visit function definitions (both async and sync) and collect info if they match target functions."""
        # No generic_visit() call: the body of a function is not descended
        # into, so nested definitions are not examined.
        if node.name not in self.target_functions:
            return

        signature = node.args
        parameters = set()
        # Scan positional-only and keyword-only parameters as well as the
        # regular ones, so a model declared after ``*`` or before ``/`` is
        # not silently skipped.
        for arg in (*signature.posonlyargs, *signature.args, *signature.kwonlyargs):
            type_name = self._annotation_type_name(arg.annotation)
            if type_name is not None:
                parameters.add((arg.arg, type_name))

        self.functions[node.name] = FunctionInfo(
            name=node.name,
            docstring=ast.get_docstring(node),
            parameters=parameters,
            file_path=self.current_file,
            line_number=node.lineno,
        )

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle async functions by delegating to the regular function visitor."""
        return self.visit_FunctionDef(node)
def find_functions_in_file(
    file_path: str, target_functions: Set[str]
) -> Dict[str, FunctionInfo]:
    """Find target functions in a Python file using AST.

    Args:
        file_path: Path of the Python source file to read (as UTF-8).
        target_functions: Names of the functions to look for.

    Returns:
        Mapping of function name to its ``FunctionInfo``. If reading, parsing
        or visiting the file raises, the error is printed and an empty dict
        is returned.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as source_file:
            source_text = source_file.read()
        doc_visitor = FastAPIDocVisitor(target_functions)
        doc_visitor.current_file = file_path
        doc_visitor.visit(ast.parse(source_text))
        return doc_visitor.functions
    except Exception as e:
        # Best effort: one unreadable or unparsable file must not abort the
        # whole scan.
        print(f"Error parsing {file_path}: {str(e)}")
        return {}
def extract_docstring_params(docstring: Optional[str]) -> Set[str]:
    """Extract parameter names from docstring.

    A parameter is any identifier followed by a colon, optionally preceded by
    ``-`` and optionally followed by a parenthesised note, e.g.::

        - parameter_name: description
        parameter_name (type): description

    The match is not anchored to line starts, so every ``identifier:``
    occurrence in the text is collected.

    Args:
        docstring: Docstring text, or None.

    Returns:
        Set of matched identifiers; empty for a missing or empty docstring.
    """
    if not docstring:
        return set()

    name_before_colon = r"-?\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\([^)]*\))?\s*:"
    return {
        found.group(1) for found in re.finditer(name_before_colon, docstring)
    }
def analyze_function(func_info: FunctionInfo) -> Dict:
    """Analyze function documentation and return validation results.

    For every parameter whose annotation name ends in ``Request`` or
    ``Response``, the same-named attribute is looked up on
    ``litellm.proxy._types`` and its ``model_fields`` keys are collected.
    Those field names are then compared with the names found in the docstring.

    Args:
        func_info: Collected information about the function to check.

    Returns:
        Dict with the function's location, whether it has a docstring, the
        collected model fields, the documented names, the fields missing from
        the docstring, and ``is_valid`` (True when nothing is missing).
    """
    documented = extract_docstring_params(func_info.docstring)

    print(f"func_info.parameters: {func_info.parameters}")
    model_field_names = set()
    for _, annotation_name in func_info.parameters:
        if not annotation_name.endswith(("Request", "Response")):
            continue
        model_cls = getattr(litellm.proxy._types, annotation_name, None)
        if model_cls is None:
            # No attribute of that name on the types module; nothing to check.
            continue
        for field_name in model_cls.model_fields.keys():
            model_field_names.add(field_name)
    print(f"pydantic_params: {model_field_names}")

    undocumented = model_field_names - documented

    return {
        "function": func_info.name,
        "file_path": func_info.file_path,
        "line_number": func_info.line_number,
        "has_docstring": bool(func_info.docstring),
        "pydantic_params": list(model_field_names),
        "documented_params": list(documented),
        "missing_params": list(undocumented),
        "is_valid": not undocumented,
    }
def print_validation_results(results: Dict) -> None:
    """Print validation results in a readable format.

    Prints a header naming the function and its location, followed by exactly
    one verdict: no docstring, no Pydantic models, everything documented, or
    the sorted list of undocumented parameters.

    Args:
        results: Result dict with the keys ``function``, ``file_path``,
            ``line_number``, ``has_docstring``, ``pydantic_params``,
            ``is_valid`` and ``missing_params``.
    """
    separator = "-" * 50
    print(f"\nChecking function: {results['function']}")
    print(f"File: {results['file_path']}:{results['line_number']}")
    print(separator)

    # The checks are ordered: a missing docstring is reported before anything
    # else, and a function without Pydantic models is never reported invalid.
    if not results["has_docstring"]:
        print("❌ No docstring found!")
    elif not results["pydantic_params"]:
        print("ℹ️ No Pydantic input models found.")
    elif results["is_valid"]:
        print("✅ All Pydantic parameters are documented!")
    else:
        print("❌ Missing documentation for parameters:")
        for missing_name in sorted(results["missing_params"]):
            print(f" - {missing_name}")
def main():
    """Validate the docstrings of the listed management endpoint functions.

    Walks ``./litellm/proxy/management_endpoints``, collects the functions
    named below from every ``.py`` file, and analyzes each one that was found.
    Listed functions that are not found in any file are skipped.

    Raises:
        Exception: On the first analyzed function whose result is not valid.
            The validation report is printed before raising.
    """
    function_names = [
        "new_end_user",
        "end_user_info",
        "update_end_user",
        "delete_end_user",
        "generate_key_fn",
        "info_key_fn",
        "update_key_fn",
        "delete_key_fn",
        "new_user",
        "new_team",
        "team_info",
        "update_team",
        "delete_team",
        "new_organization",
        "update_organization",
        "delete_organization",
        "list_organization",
        "user_update",
        "new_budget",
        "info_budget",
        "update_budget",
        "delete_budget",
        "list_budget",
    ]
    # directory = "../../litellm/proxy/management_endpoints"  # LOCAL
    directory = "./litellm/proxy/management_endpoints"

    # Convert function names to set for faster lookup
    target_functions = set(function_names)
    found_functions: Dict[str, FunctionInfo] = {}

    # Walk through directory
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                found = find_functions_in_file(file_path, target_functions)
                found_functions.update(found)

    # Analyze in the order the names are listed above.
    for func_name in function_names:
        if func_name not in found_functions:
            continue
        result = analyze_function(found_functions[func_name])
        if not result["is_valid"]:
            # Print the readable report first. print_validation_results()
            # returns None, so it must not be used as the exception message.
            print_validation_results(result)
            raise Exception(
                f"Docstring validation failed for {result['function']} "
                f"({result['file_path']}:{result['line_number']}); "
                f"undocumented parameters: {sorted(result['missing_params'])}"
            )


if __name__ == "__main__":
    main()