Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import os | |
| import sys | |
| import uuid | |
| from typing import Optional, cast | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| from fastapi import Request | |
| from fastapi.testclient import TestClient | |
| sys.path.insert( | |
| 0, os.path.abspath("../../../") | |
| ) # Adds the parent directory to the system path | |
| import litellm | |
| from litellm.proxy._types import NewTeamRequest | |
| from litellm.proxy.auth.handle_jwt import JWTHandler | |
| from litellm.proxy.management_endpoints.types import CustomOpenID | |
| from litellm.proxy.management_endpoints.ui_sso import ( | |
| GoogleSSOHandler, | |
| MicrosoftSSOHandler, | |
| SSOAuthenticationHandler, | |
| ) | |
| from litellm.types.proxy.management_endpoints.ui_sso import ( | |
| DefaultTeamSSOParams, | |
| MicrosoftGraphAPIUserGroupDirectoryObject, | |
| MicrosoftGraphAPIUserGroupResponse, | |
| MicrosoftServicePrincipalTeam, | |
| ) | |
def test_microsoft_sso_handler_openid_from_response_user_principal_name():
    """Build a CustomOpenID from a Microsoft payload keyed by ``userPrincipalName``.

    The payload carries no ``mail`` key; the test expects the email on the
    resulting object to equal the ``userPrincipalName`` value.
    """
    team_ids = ["team1", "team2"]
    graph_payload = {
        "userPrincipalName": "test@example.com",
        "displayName": "Test User",
        "id": "user123",
        "givenName": "Test",
        "surname": "User",
        "some_other_field": "value",  # extra key, not covered by any assertion
    }

    openid = MicrosoftSSOHandler.openid_from_response(
        response=graph_payload, team_ids=team_ids
    )

    assert isinstance(openid, CustomOpenID)
    wanted = {
        "email": "test@example.com",
        "display_name": "Test User",
        "provider": "microsoft",
        "id": "user123",
        "first_name": "Test",
        "last_name": "User",
        "team_ids": team_ids,
    }
    for attr, value in wanted.items():
        assert getattr(openid, attr) == value, attr
def test_microsoft_sso_handler_openid_from_response():
    """Build a CustomOpenID from a Microsoft payload that carries ``mail``.

    Each asserted attribute is compared against the matching value of the
    mocked Microsoft response, and the team ids are expected unchanged.
    """
    team_ids = ["team1", "team2"]
    graph_payload = {
        "mail": "test@example.com",
        "displayName": "Test User",
        "id": "user123",
        "givenName": "Test",
        "surname": "User",
        "some_other_field": "value",  # extra key, not covered by any assertion
    }

    openid = MicrosoftSSOHandler.openid_from_response(
        response=graph_payload, team_ids=team_ids
    )

    assert isinstance(openid, CustomOpenID)
    wanted = {
        "email": "test@example.com",
        "display_name": "Test User",
        "provider": "microsoft",
        "id": "user123",
        "first_name": "Test",
        "last_name": "User",
        "team_ids": team_ids,
    }
    for attr, value in wanted.items():
        assert getattr(openid, attr) == value, attr
def test_microsoft_sso_handler_with_empty_response():
    """A ``None`` response must still yield a CustomOpenID with empty user fields.

    Only the provider ("microsoft") and the passed-in empty team list are
    expected to be populated.
    """
    openid = MicrosoftSSOHandler.openid_from_response(response=None, team_ids=[])

    assert isinstance(openid, CustomOpenID)
    # Every user-derived attribute is expected to be absent.
    for attr in ("email", "display_name"):
        assert getattr(openid, attr) is None, attr
    assert openid.provider == "microsoft"
    for attr in ("id", "first_name", "last_name"):
        assert getattr(openid, attr) is None, attr
    assert openid.team_ids == []
def test_get_microsoft_callback_response():
    """Microsoft callback handling should turn the SSO payload into a CustomOpenID.

    ``MicrosoftSSO.verify_and_process`` is replaced by an ``AsyncMock`` that
    resolves to the mocked payload. An ``AsyncMock`` is used instead of a
    pre-resolved ``asyncio.Future`` because constructing a Future outside a
    running event loop is deprecated and ties it to a loop other than the one
    ``asyncio.run`` creates.
    """
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "mail": "microsoft_user@example.com",
        "displayName": "Microsoft User",
        "id": "msft123",
        "givenName": "Microsoft",
        "surname": "User",
    }
    microsoft_env = {
        "MICROSOFT_CLIENT_SECRET": "mock_secret",
        "MICROSOFT_TENANT": "mock_tenant",
    }

    with patch.dict(os.environ, microsoft_env), patch(
        "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
        new=AsyncMock(return_value=mock_response),
    ):
        # Act
        result = asyncio.run(
            MicrosoftSSOHandler.get_microsoft_callback_response(
                request=mock_request,
                microsoft_client_id="mock_client_id",
                redirect_url="http://mock_redirect_url",
            )
        )

    # Assert
    assert isinstance(result, CustomOpenID)
    assert result.email == "microsoft_user@example.com"
    assert result.display_name == "Microsoft User"
    assert result.provider == "microsoft"
    assert result.id == "msft123"
    assert result.first_name == "Microsoft"
    assert result.last_name == "User"
def test_get_microsoft_callback_response_raw_sso_response():
    """With ``return_raw_sso_response=True`` the handler should return the raw dict.

    ``MicrosoftSSO.verify_and_process`` is replaced by an ``AsyncMock`` that
    resolves to the mocked payload. An ``AsyncMock`` is used instead of a
    pre-resolved ``asyncio.Future`` because constructing a Future outside a
    running event loop is deprecated and ties it to a loop other than the one
    ``asyncio.run`` creates.
    """
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "mail": "microsoft_user@example.com",
        "displayName": "Microsoft User",
        "id": "msft123",
        "givenName": "Microsoft",
        "surname": "User",
    }
    microsoft_env = {
        "MICROSOFT_CLIENT_SECRET": "mock_secret",
        "MICROSOFT_TENANT": "mock_tenant",
    }

    with patch.dict(os.environ, microsoft_env), patch(
        "fastapi_sso.sso.microsoft.MicrosoftSSO.verify_and_process",
        new=AsyncMock(return_value=mock_response),
    ):
        # Act
        result = asyncio.run(
            MicrosoftSSOHandler.get_microsoft_callback_response(
                request=mock_request,
                microsoft_client_id="mock_client_id",
                redirect_url="http://mock_redirect_url",
                return_raw_sso_response=True,
            )
        )

    # Assert: the payload comes back untouched, keyed by Microsoft's field names
    assert isinstance(result, dict)
    assert result["mail"] == "microsoft_user@example.com"
    assert result["displayName"] == "Microsoft User"
    assert result["id"] == "msft123"
    assert result["givenName"] == "Microsoft"
    assert result["surname"] == "User"
def test_get_google_callback_response():
    """Google callback handling should hand back the payload from ``verify_and_process``.

    ``GoogleSSO.verify_and_process`` is replaced by an ``AsyncMock`` that
    resolves to the mocked payload. An ``AsyncMock`` is used instead of a
    pre-resolved ``asyncio.Future`` because constructing a Future outside a
    running event loop is deprecated and ties it to a loop other than the one
    ``asyncio.run`` creates.
    """
    # Arrange
    mock_request = MagicMock(spec=Request)
    mock_response = {
        "email": "google_user@example.com",
        "name": "Google User",
        "sub": "google123",
        "given_name": "Google",
        "family_name": "User",
    }

    with patch.dict(os.environ, {"GOOGLE_CLIENT_SECRET": "mock_secret"}), patch(
        "fastapi_sso.sso.google.GoogleSSO.verify_and_process",
        new=AsyncMock(return_value=mock_response),
    ):
        # Act
        result = asyncio.run(
            GoogleSSOHandler.get_google_callback_response(
                request=mock_request,
                google_client_id="mock_client_id",
                redirect_url="http://mock_redirect_url",
            )
        )

    # Assert
    assert isinstance(result, dict)
    assert result.get("email") == "google_user@example.com"
    assert result.get("name") == "Google User"
    assert result.get("sub") == "google123"
    assert result.get("given_name") == "Google"
    assert result.get("family_name") == "User"
@pytest.mark.asyncio
async def test_get_user_groups_from_graph_api():
    """A single-page Graph API response should yield the ids of both groups.

    The asyncio marker is required so pytest-asyncio awaits this coroutine
    when running in strict mode.
    """
    # Arrange
    mock_response = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group1",
                "displayName": "Group 1",
            },
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group2",
                "displayName": "Group 2",
            },
        ],
    }

    async def mock_get(*args, **kwargs):
        # Stand-in for the HTTP client's awaited ``get``; only ``.json()`` is stubbed.
        mock = MagicMock()
        mock.json.return_value = mock_response
        return mock

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as mock_client:
        mock_client.return_value = MagicMock()
        mock_client.return_value.get = mock_get

        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 2
    assert "group1" in result
    assert "group2" in result
@pytest.mark.asyncio
async def test_get_user_groups_pagination():
    """Groups spread over two pages should all be collected.

    The first page carries an ``@odata.nextLink``; the second does not. The
    test checks that both group ids are returned and that exactly two GET
    calls were made. The asyncio marker is required so pytest-asyncio awaits
    this coroutine when running in strict mode.
    """
    # Arrange
    first_response = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "@odata.nextLink": "https://graph.microsoft.com/v1.0/me/memberOf?$skiptoken=page2",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group1",
                "displayName": "Group 1",
            },
        ],
    }
    second_response = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "value": [
            {
                "@odata.type": "#microsoft.graph.group",
                "id": "group2",
                "displayName": "Group 2",
            },
        ],
    }
    responses = [first_response, second_response]
    # Mutable cell so the nested coroutine can advance the page counter.
    current_response = {"index": 0}

    async def mock_get(*args, **kwargs):
        # Serve the pages in order, one per call.
        mock = MagicMock()
        mock.json.return_value = responses[current_response["index"]]
        current_response["index"] += 1
        return mock

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as mock_client:
        mock_client.return_value = MagicMock()
        mock_client.return_value.get = mock_get

        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 2
    assert "group1" in result
    assert "group2" in result
    assert current_response["index"] == 2  # Verify both pages were fetched
@pytest.mark.asyncio
async def test_get_user_groups_empty_response():
    """A Graph API response with an empty ``value`` list should yield no group ids.

    The asyncio marker is required so pytest-asyncio awaits this coroutine
    when running in strict mode.
    """
    # Arrange
    mock_response = {
        "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        "value": [],
    }

    async def mock_get(*args, **kwargs):
        # Stand-in for the HTTP client's awaited ``get``; only ``.json()`` is stubbed.
        mock = MagicMock()
        mock.json.return_value = mock_response
        return mock

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as mock_client:
        mock_client.return_value = MagicMock()
        mock_client.return_value.get = mock_get

        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert
    assert isinstance(result, list)
    assert len(result) == 0
@pytest.mark.asyncio
async def test_get_user_groups_error_handling():
    """When the HTTP GET raises, the handler is expected to return an empty list.

    The asyncio marker is required so pytest-asyncio awaits this coroutine
    when running in strict mode.
    """

    # Arrange
    async def mock_get(*args, **kwargs):
        raise Exception("API Error")

    with patch(
        "litellm.proxy.management_endpoints.ui_sso.get_async_httpx_client"
    ) as mock_client:
        mock_client.return_value = MagicMock()
        mock_client.return_value.get = mock_get

        # Act
        result = await MicrosoftSSOHandler.get_user_groups_from_graph_api(
            access_token="mock_token"
        )

    # Assert: the failure is not propagated to the caller
    assert isinstance(result, list)
    assert len(result) == 0
def test_get_group_ids_from_graph_api_response():
    """Only directory objects that have an id should contribute a group id.

    The response holds two groups with ids and one whose id is ``None``; the
    test expects exactly the two real ids to be returned.
    """

    def make_group(group_id, display_name):
        # All fields other than id/displayName are identical across entries.
        return MicrosoftGraphAPIUserGroupDirectoryObject(
            odata_type="#microsoft.graph.group",
            id=group_id,
            displayName=display_name,
            description=None,
            deletedDateTime=None,
            roleTemplateId=None,
        )

    graph_response = MicrosoftGraphAPIUserGroupResponse(
        odata_context="https://graph.microsoft.com/v1.0/$metadata#directoryObjects",
        odata_nextLink=None,
        value=[
            make_group("group1", "Group 1"),
            make_group("group2", "Group 2"),
            make_group(None, "Invalid Group"),  # exercises the None-id case
        ],
    )

    group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response(
        graph_response
    )

    assert isinstance(group_ids, list)
    assert len(group_ids) == 2
    for expected_id in ("group1", "group2"):
        assert expected_id in group_ids
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "team_params",
    [
        # NOTE(review): keyword names mirror the keys asserted below — confirm
        # they match the DefaultTeamSSOParams fields.
        DefaultTeamSSOParams(
            max_budget=10, budget_duration="1d", models=["special-gpt-5"]
        ),
        # Same defaults supplied as a plain dict.
        {"max_budget": 10, "budget_duration": "1d", "models": ["special-gpt-5"]},
    ],
)
async def test_default_team_params(team_params):
    """
    When litellm.default_team_params is set, it should be used to create a new team

    ``team_params`` is supplied by the parametrization above. The global is
    set through ``patch.object`` so it is restored after the test instead of
    leaking into later tests.
    """
    # Arrange: Prisma client with no pre-existing team
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
    mock_prisma.db.litellm_teamtable.create = AsyncMock()
    mock_prisma.get_data = AsyncMock(return_value=None)
    # Pass the row through unchanged so the create() payload can be inspected.
    mock_prisma.jsonify_team_object = MagicMock(side_effect=lambda db_data: db_data)
    team_id = str(uuid.uuid4())

    with patch.object(litellm, "default_team_params", team_params), patch(
        "litellm.proxy.proxy_server.prisma_client", mock_prisma
    ):
        # Act
        await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
            service_principal_teams=[
                MicrosoftServicePrincipalTeam(
                    principalId=team_id,
                    principalDisplayName="Test Team",
                )
            ]
        )

    # Assert: team was created once, with the configured defaults applied
    mock_prisma.db.litellm_teamtable.create.assert_called_once()
    create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
        "data"
    ]
    assert create_call_args["team_id"] == team_id
    assert create_call_args["team_alias"] == "Test Team"
    assert create_call_args["max_budget"] == 10
    assert create_call_args["budget_duration"] == "1d"
    assert create_call_args["models"] == ["special-gpt-5"]
@pytest.mark.asyncio
async def test_create_team_without_default_params():
    """
    Test team creation when litellm.default_team_params is None
    Should create team with just the basic required fields

    The global is set through ``patch.object`` so it is restored after the
    test instead of leaking into later tests.
    """
    # Arrange: Prisma client with no pre-existing team
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_teamtable.find_first = AsyncMock(return_value=None)
    mock_prisma.db.litellm_teamtable.create = AsyncMock()
    mock_prisma.get_data = AsyncMock(return_value=None)
    # Pass the row through unchanged so the create() payload can be inspected.
    mock_prisma.jsonify_team_object = MagicMock(side_effect=lambda db_data: db_data)
    team_id = str(uuid.uuid4())

    with patch.object(litellm, "default_team_params", None), patch(
        "litellm.proxy.proxy_server.prisma_client", mock_prisma
    ):
        # Act
        await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids(
            service_principal_teams=[
                MicrosoftServicePrincipalTeam(
                    principalId=team_id,
                    principalDisplayName="Test Team",
                )
            ]
        )

    # Assert
    mock_prisma.db.litellm_teamtable.create.assert_called_once()
    create_call_args = mock_prisma.db.litellm_teamtable.create.call_args.kwargs[
        "data"
    ]
    assert create_call_args["team_id"] == team_id
    assert create_call_args["team_alias"] == "Test Team"
    # Should not have any of the optional fields
    assert "max_budget" not in create_call_args
    assert "budget_duration" not in create_call_args
    assert create_call_args["models"] == []
def test_apply_user_info_values_to_sso_user_defined_values():
    """The stored user's id should replace the id in the SSO-defined values.

    The DB record has ``user_id="123"`` while the SSO values carry ``"456"``;
    after applying, the returned mapping is expected to hold ``"123"``.
    """
    from litellm.proxy._types import LiteLLM_UserTable
    from litellm.proxy.management_endpoints.ui_sso import (
        apply_user_info_values_to_sso_user_defined_values,
    )

    sso_values = {
        "user_id": "456",
        "user_email": "test@example.com",
        "user_role": "admin",
    }
    db_user = LiteLLM_UserTable(
        user_id="123",
        user_email="test@example.com",
        user_role="admin",
    )

    merged = apply_user_info_values_to_sso_user_defined_values(
        user_info=db_user,
        user_defined_values=sso_values,
    )

    assert merged["user_id"] == "123"
@pytest.mark.asyncio
async def test_get_user_info_from_db():
    """``get_user_info_from_db`` should look the user up by the SSO result's id.

    ``get_user_object`` is patched on the ``ui_sso`` module; the test checks
    that it is called exactly once and that the ``user_id`` keyword equals the
    id carried by the ``CustomOpenID`` result ("krrishd").
    """
    from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db

    prisma_client = MagicMock()
    user_api_key_cache = MagicMock()
    proxy_logging_obj = MagicMock()
    user_email = "krrishdholakia@gmail.com"
    user_defined_values = {
        "models": [],
        "user_id": "krrishd",
        "user_email": "krrishdholakia@gmail.com",
        "max_budget": None,
        "user_role": None,
        "budget_duration": None,
    }
    args = {
        "result": CustomOpenID(
            id="krrishd",
            email="krrishdholakia@gmail.com",
            first_name=None,
            last_name=None,
            display_name="a3f1c107-04dc-4c93-ae60-7f32eb4b05ce",
            picture=None,
            provider=None,
            team_ids=[],
        ),
        "prisma_client": prisma_client,
        "user_api_key_cache": user_api_key_cache,
        "proxy_logging_obj": proxy_logging_obj,
        "user_email": user_email,
        "user_defined_values": user_defined_values,
    }

    with patch.object(
        litellm.proxy.management_endpoints.ui_sso, "get_user_object"
    ) as mock_get_user_object:
        await get_user_info_from_db(**args)

    mock_get_user_object.assert_called_once()
    # Compare, don't assign: writing into call_args.kwargs would verify nothing.
    assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd"
@pytest.mark.asyncio
async def test_get_user_info_from_db_alternate_user_id():
    """When ``alternate_user_id`` is supplied, the lookup should use that id.

    ``get_user_object`` is patched on the ``ui_sso`` module; the test checks
    that it is called exactly once and that the ``user_id`` keyword equals the
    alternate id ("krrishd-email1234") rather than the SSO result's own id.
    """
    from litellm.proxy.management_endpoints.ui_sso import get_user_info_from_db

    prisma_client = MagicMock()
    user_api_key_cache = MagicMock()
    proxy_logging_obj = MagicMock()
    user_email = "krrishdholakia@gmail.com"
    user_defined_values = {
        "models": [],
        "user_id": "krrishd",
        "user_email": "krrishdholakia@gmail.com",
        "max_budget": None,
        "user_role": None,
        "budget_duration": None,
    }
    args = {
        "result": CustomOpenID(
            id="krrishd",
            email="krrishdholakia@gmail.com",
            first_name=None,
            last_name=None,
            display_name="a3f1c107-04dc-4c93-ae60-7f32eb4b05ce",
            picture=None,
            provider=None,
            team_ids=[],
        ),
        "prisma_client": prisma_client,
        "user_api_key_cache": user_api_key_cache,
        "proxy_logging_obj": proxy_logging_obj,
        "user_email": user_email,
        "user_defined_values": user_defined_values,
        "alternate_user_id": "krrishd-email1234",
    }

    with patch.object(
        litellm.proxy.management_endpoints.ui_sso, "get_user_object"
    ) as mock_get_user_object:
        await get_user_info_from_db(**args)

    mock_get_user_object.assert_called_once()
    # Compare, don't assign: writing into call_args.kwargs would verify nothing.
    assert mock_get_user_object.call_args.kwargs["user_id"] == "krrishd-email1234"
@pytest.mark.asyncio
async def test_check_and_update_if_proxy_admin_id():
    """
    Test that a user with matching PROXY_ADMIN_ID gets their role updated to admin

    The asyncio marker is required so pytest-asyncio awaits this coroutine
    when running in strict mode.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.management_endpoints.ui_sso import (
        check_and_update_if_proxy_admin_id,
    )

    # Mock Prisma client
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_usertable.update = AsyncMock()

    # Set up test data: a non-admin user whose id equals PROXY_ADMIN_ID
    test_user_id = "test_admin_123"
    test_user_role = "user"

    with patch.dict(os.environ, {"PROXY_ADMIN_ID": test_user_id}):
        # Act
        updated_role = await check_and_update_if_proxy_admin_id(
            user_role=test_user_role, user_id=test_user_id, prisma_client=mock_prisma
        )

    # Assert: role is promoted and persisted with a single update call
    assert updated_role == LitellmUserRoles.PROXY_ADMIN.value
    mock_prisma.db.litellm_usertable.update.assert_called_once_with(
        where={"user_id": test_user_id},
        data={"user_role": LitellmUserRoles.PROXY_ADMIN.value},
    )
@pytest.mark.asyncio
async def test_check_and_update_if_proxy_admin_id_already_admin():
    """
    Test that a user who is already an admin doesn't get their role updated

    The asyncio marker is required so pytest-asyncio awaits this coroutine
    when running in strict mode.
    """
    from litellm.proxy._types import LitellmUserRoles
    from litellm.proxy.management_endpoints.ui_sso import (
        check_and_update_if_proxy_admin_id,
    )

    # Mock Prisma client
    mock_prisma = MagicMock()
    mock_prisma.db.litellm_usertable.update = AsyncMock()

    # Set up test data: the user already holds the proxy-admin role
    test_user_id = "test_admin_123"
    test_user_role = LitellmUserRoles.PROXY_ADMIN.value

    with patch.dict(os.environ, {"PROXY_ADMIN_ID": test_user_id}):
        # Act
        updated_role = await check_and_update_if_proxy_admin_id(
            user_role=test_user_role, user_id=test_user_id, prisma_client=mock_prisma
        )

    # Assert: role unchanged and no database write issued
    assert updated_role == LitellmUserRoles.PROXY_ADMIN.value
    mock_prisma.db.litellm_usertable.update.assert_not_called()