# DungeonMaster-AI / src/game/game_state_manager.py
# bhupesh-sf's picture
# first commit
# f8ba6bf verified
"""
DungeonMaster AI - Game State Manager
High-level manager for game state, integrating MCP tools,
character caching, combat tracking, and save/load functionality.
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from .event_logger import EventLogger
from .models import (
CharacterSnapshot,
CombatantStatus,
Combatant,
CombatState,
EventType,
GameSaveData,
NPCInfo,
SceneInfo,
SessionEvent,
)
if TYPE_CHECKING:
    # Imported only for type annotations; avoids a runtime dependency/cycle.
    from src.mcp_integration.toolkit_client import TTRPGToolkitClient

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class GameStateManager:
    """
    High-level manager for game state.

    Orchestrates MCP tool integration, character caching, combat state,
    event logging, and save/load functionality. This is the main interface
    for game state management used by the agent orchestrator.
    """

    # Version string written into every save produced by this manager.
    _SAVE_VERSION = "1.0.0"

    def __init__(
        self,
        toolkit_client: TTRPGToolkitClient | None = None,
        max_recent_events: int = 50,
        character_cache_ttl: int = 300,  # 5 minutes
    ) -> None:
        """
        Initialize the game state manager.

        Args:
            toolkit_client: Optional MCP toolkit client for remote operations
            max_recent_events: Maximum events to keep in memory (forwarded to
                the EventLogger as ``max_events``)
            character_cache_ttl: Character cache TTL in seconds
        """
        self._toolkit_client = toolkit_client
        self._max_recent_events = max_recent_events
        self._character_cache_ttl = character_cache_ttl

        # Session state
        self._session_id = str(uuid.uuid4())
        self._mcp_session_id: str | None = None
        self._started_at = datetime.now()
        self._turn_count = 0

        # Party state
        self._party_ids: list[str] = []
        self._active_character_id: str | None = None

        # Location state
        self._current_location = "Unknown"
        self._current_scene: SceneInfo | None = None

        # Combat state
        self._in_combat = False
        self._combat_state: CombatState | None = None

        # Tracking
        self._story_flags: dict[str, object] = {}
        self._known_npcs: dict[str, NPCInfo] = {}

        # Character cache: snapshot per character ID, plus the time it was
        # last fetched from MCP (used for TTL staleness checks).
        self._character_cache: dict[str, CharacterSnapshot] = {}
        self._character_cache_times: dict[str, datetime] = {}

        # Adventure
        self._adventure_name: str | None = None

        # Event logger
        self._event_logger = EventLogger(
            toolkit_client=toolkit_client,
            max_events=max_recent_events,
        )

        logger.debug("GameStateManager initialized with session: %s", self._session_id)

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def session_id(self) -> str:
        """Get the current session ID."""
        return self._session_id

    @property
    def mcp_session_id(self) -> str | None:
        """Get the MCP session ID if connected."""
        return self._mcp_session_id

    @property
    def turn_count(self) -> int:
        """Get the current turn count."""
        return self._turn_count

    @property
    def party_ids(self) -> list[str]:
        """Get a copy of the list of character IDs in the party."""
        return self._party_ids.copy()

    @property
    def active_character_id(self) -> str | None:
        """Get the active character ID."""
        return self._active_character_id

    @property
    def current_location(self) -> str:
        """Get the current location name."""
        return self._current_location

    @property
    def current_scene(self) -> SceneInfo | None:
        """Get the current scene info."""
        return self._current_scene

    @property
    def in_combat(self) -> bool:
        """Check if combat is active."""
        return self._in_combat

    @property
    def combat_state(self) -> CombatState | None:
        """Get the current combat state."""
        return self._combat_state

    @property
    def adventure_name(self) -> str | None:
        """Get the loaded adventure name."""
        return self._adventure_name

    @property
    def story_flags(self) -> dict[str, object]:
        """Get a shallow copy of all story flags."""
        return self._story_flags.copy()

    @property
    def event_logger(self) -> EventLogger:
        """Get the event logger."""
        return self._event_logger

    @property
    def recent_events(self) -> list[SessionEvent]:
        """Get recent events from the logger."""
        return self._event_logger.events

    # =========================================================================
    # Session Lifecycle
    # =========================================================================

    async def new_game(self, adventure: str | None = None) -> str:
        """
        Start a new game session.

        Ends any MCP session still open from a previous game, resets all
        local state, and starts a fresh MCP session when the toolkit client
        is connected.

        Args:
            adventure: Optional adventure name

        Returns:
            The new session ID
        """
        # Close the previous server-side session first so it is not
        # orphaned, and so a stale ID never survives into the new game.
        await self._end_mcp_session()

        # Reset state
        self._session_id = str(uuid.uuid4())
        self._started_at = datetime.now()
        self._turn_count = 0
        self._party_ids.clear()
        self._active_character_id = None
        self._current_location = "Unknown"
        self._current_scene = None
        self._in_combat = False
        self._combat_state = None
        self._story_flags.clear()
        self._known_npcs.clear()
        self._character_cache.clear()
        self._character_cache_times.clear()
        self._adventure_name = adventure

        # Clear event logger
        self._event_logger.clear()
        self._event_logger.set_current_turn(0)

        # Start MCP session if connected (best effort)
        try:
            mcp_session_id = await self._start_mcp_session(
                adventure or "DungeonMaster AI Session"
            )
        except Exception as e:
            logger.warning("Failed to start MCP session: %s", e)
            self._mcp_session_id = None
        else:
            if mcp_session_id is not None:
                logger.info("Started MCP session: %s", mcp_session_id)

        # Log system event
        self._event_logger.log_system(
            "New game started" + (f": {adventure}" if adventure else ""),
            {"adventure": adventure},
        )

        logger.info("New game started with session: %s", self._session_id)
        return self._session_id

    async def end_game(self) -> None:
        """
        End the current game session.

        Ends the MCP session if one is active (best effort) and logs a
        "Game ended" system event with the turn count and duration.
        """
        await self._end_mcp_session()

        self._event_logger.log_system(
            "Game ended",
            {"turns": self._turn_count, "duration_minutes": self._get_session_duration()},
        )

    async def _start_mcp_session(self, campaign_name: str) -> str | None:
        """
        Start a D&D 5e MCP session and record its ID.

        Args:
            campaign_name: Campaign name sent to ``mcp_start_session``

        Returns:
            The new MCP session ID, or None when no connected toolkit client
            is available or the response carries no usable session ID (in
            which case nothing is recorded). Exceptions raised by the client
            propagate to the caller.
        """
        if not (self._toolkit_client and self._toolkit_client.is_connected):
            return None

        result = await self._toolkit_client.call_tool(
            "mcp_start_session",
            {
                "campaign_name": campaign_name,
                "system": "dnd5e",
            },
        )
        if not isinstance(result, dict):
            return None

        raw_session_id = result.get("session_id")
        # Treat a missing/null/empty ID as "no session" rather than storing
        # "" or "None" as if it were a real session ID.
        if raw_session_id is None or raw_session_id == "":
            return None

        session_id = str(raw_session_id)
        self._mcp_session_id = session_id
        self._event_logger.set_mcp_session_id(session_id)
        return session_id

    async def _end_mcp_session(self) -> None:
        """
        Best-effort close of the active MCP session, then forget its ID.

        Failures from the toolkit client are logged as warnings and never
        raised. The manager's ``mcp_session_id`` is always None afterwards.
        NOTE(review): the EventLogger's own MCP session ID is not reset here
        (the original code never cleared it either) — confirm whether
        ``set_mcp_session_id`` accepts None.
        """
        session_id = self._mcp_session_id
        self._mcp_session_id = None
        if not (self._toolkit_client and session_id):
            return

        try:
            await self._toolkit_client.call_tool(
                "mcp_end_session",
                {"session_id": session_id},
            )
            logger.info("Ended MCP session: %s", session_id)
        except Exception as e:
            logger.warning("Failed to end MCP session: %s", e)

    def _get_session_duration(self) -> int:
        """Get the session duration in whole minutes (truncated)."""
        delta = datetime.now() - self._started_at
        return int(delta.total_seconds() / 60)

    # =========================================================================
    # Character Management
    # =========================================================================

    async def add_character(self, character_id: str) -> CharacterSnapshot | None:
        """
        Add a character to the party.

        Fetches character data from MCP and caches it. The first character
        added becomes the active character.

        Args:
            character_id: ID of character to add

        Returns:
            CharacterSnapshot if successful (or the cached snapshot if the
            character is already in the party), None otherwise
        """
        # Already in party: nothing to add
        if character_id in self._party_ids:
            return self._character_cache.get(character_id)

        snapshot = await self._fetch_character(character_id)
        if snapshot is None:
            logger.warning("Failed to fetch character: %s", character_id)
            return None

        self._party_ids.append(character_id)

        # Set as active if first character
        if self._active_character_id is None:
            self._active_character_id = character_id

        self._event_logger.log_system(
            f"{snapshot.name} joined the party",
            {"character_id": character_id, "name": snapshot.name},
        )

        logger.info("Added character to party: %s (%s)", snapshot.name, character_id)
        return snapshot

    async def _fetch_character(self, character_id: str) -> CharacterSnapshot | None:
        """
        Fetch character data from MCP and cache it.

        Without a connected client, or when the call raises, the previously
        cached snapshot (if any) is returned instead. A non-dict result or
        one whose ``success`` flag is falsy yields None.

        Args:
            character_id: Character ID to fetch

        Returns:
            CharacterSnapshot if successful, None otherwise
        """
        if not self._toolkit_client or not self._toolkit_client.is_connected:
            # Offline: fall back to whatever is cached
            return self._character_cache.get(character_id)

        try:
            result = await self._toolkit_client.call_tool(
                "mcp_get_character",
                {"character_id": character_id},
            )

            if not isinstance(result, dict) or not result.get("success", False):
                return None

            snapshot = CharacterSnapshot.from_mcp_result(result)
            self._character_cache[character_id] = snapshot
            self._character_cache_times[character_id] = datetime.now()
            return snapshot

        except Exception as e:
            logger.warning("Failed to fetch character %s: %s", character_id, e)
            return self._character_cache.get(character_id)

    async def get_active_character(self) -> CharacterSnapshot | None:
        """
        Get the active character's data.

        Refreshes from MCP if the cache entry is older than the TTL.

        Returns:
            CharacterSnapshot if available, None otherwise
        """
        if self._active_character_id is None:
            return None

        if self._is_cache_stale(self._active_character_id):
            await self.refresh_character(self._active_character_id)

        return self._character_cache.get(self._active_character_id)

    def _is_cache_stale(self, character_id: str) -> bool:
        """Return True if the character was never fetched or its fetch is older than the TTL."""
        cache_time = self._character_cache_times.get(character_id)
        if cache_time is None:
            return True
        age = datetime.now() - cache_time
        return age.total_seconds() > self._character_cache_ttl

    async def refresh_character(
        self,
        character_id: str,
    ) -> CharacterSnapshot | None:
        """
        Force refresh a character's data from MCP.

        Args:
            character_id: Character to refresh

        Returns:
            Updated CharacterSnapshot if successful
        """
        return await self._fetch_character(character_id)

    def set_active_character(self, character_id: str) -> bool:
        """
        Set the active character.

        Args:
            character_id: Character ID to make active

        Returns:
            True if successful, False if not in party
        """
        if character_id not in self._party_ids:
            return False
        self._active_character_id = character_id
        return True

    def remove_character(self, character_id: str) -> None:
        """
        Remove a character from the party and drop its cache entry.

        If the removed character was active, the first remaining party
        member (or None) becomes active.

        Args:
            character_id: Character to remove
        """
        if character_id in self._party_ids:
            self._party_ids.remove(character_id)

        if self._active_character_id == character_id:
            self._active_character_id = self._party_ids[0] if self._party_ids else None

        self._character_cache.pop(character_id, None)
        self._character_cache_times.pop(character_id, None)

    def get_character_snapshot(
        self,
        character_id: str,
    ) -> CharacterSnapshot | None:
        """
        Get a character's cached snapshot.

        Args:
            character_id: Character ID

        Returns:
            CharacterSnapshot if cached, None otherwise
        """
        return self._character_cache.get(character_id)

    def get_party_snapshots(self) -> list[CharacterSnapshot]:
        """
        Get all party members' cached snapshots, in party order.

        Party members with no cached snapshot are omitted.

        Returns:
            List of CharacterSnapshot objects
        """
        return [
            self._character_cache[cid]
            for cid in self._party_ids
            if cid in self._character_cache
        ]

    # =========================================================================
    # Tool Result Processing
    # =========================================================================

    @staticmethod
    def _tool_matches(tool_name: str, key: str) -> bool:
        """
        Return True if ``key`` appears as whole ``_``-separated segment(s) of ``tool_name``.

        Matching is case-insensitive and treats any non-alphanumeric
        character as a separator, so ``mcp_long_rest`` matches ``rest`` but
        ``mcp_restore_slots`` does not.
        """
        normalized = "".join(ch if ch.isalnum() else "_" for ch in tool_name.lower())
        return f"_{key}_" in f"_{normalized}_"

    @staticmethod
    def _coerce_int(value: object, default: int | None = None) -> int | None:
        """
        Convert a loosely-typed payload value to ``int``.

        Returns ``default`` when ``value`` is None or ``int()`` rejects it.
        """
        if value is None:
            return default
        try:
            return int(value)  # type: ignore[call-overload]
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _toggle_condition(conditions: list[str], condition: str, added: bool) -> list[str]:
        """Return a copy of ``conditions`` with ``condition`` added (once) or removed."""
        updated = list(conditions)
        if added:
            if condition not in updated:
                updated.append(condition)
        elif condition in updated:
            updated.remove(condition)
        return updated

    async def update_from_tool_calls(
        self,
        tool_results: list[dict[str, object]],
    ) -> None:
        """
        Update state based on MCP tool call results.

        Each entry is routed by its tool name, matched on whole segments
        (see ``_tool_matches``) so unrelated tools such as
        ``mcp_restore_*`` are not mistaken for a rest. Entries that are not
        dicts, or whose ``result`` is not a dict, are skipped.

        Args:
            tool_results: List of {tool_name, result} dicts
        """
        for entry in tool_results:
            if not isinstance(entry, dict):
                continue

            tool_name = str(entry.get("tool_name", ""))
            result = entry.get("result", {})
            if not isinstance(result, dict):
                continue

            if self._tool_matches(tool_name, "modify_hp"):
                await self._process_hp_change(result)
            elif self._tool_matches(tool_name, "start_combat"):
                await self._process_combat_start(result)
            elif self._tool_matches(tool_name, "end_combat"):
                await self._process_combat_end(result)
            elif self._tool_matches(tool_name, "next_turn"):
                await self._process_next_turn(result)
            elif self._tool_matches(tool_name, "add_condition"):
                await self._process_condition_change(result, added=True)
            elif self._tool_matches(tool_name, "remove_condition"):
                await self._process_condition_change(result, added=False)
            elif self._tool_matches(tool_name, "rest"):
                await self._process_rest(result)

    async def _process_hp_change(self, result: dict[str, object]) -> None:
        """
        Process an HP modification result.

        Updates the cached snapshot and, during combat, the matching
        combatant's HP. A drop from above 0 to 0 or below is logged as lethal
        damage (unless ``is_damage`` is falsy) and marks the combatant
        unconscious. Healing an unconscious combatant above 0 marks it active.

        Missing fields fall back to cached values rather than placeholders.
        A missing ``max_hp`` keeps the cached maximum. A missing
        ``previous_hp`` uses the cached current HP. A result with no usable
        ``new_hp``/``current_hp`` is ignored.
        """
        character_id = str(result.get("character_id", ""))
        if not character_id:
            return

        raw_new_hp = result.get("new_hp")
        if raw_new_hp is None:
            raw_new_hp = result.get("current_hp")
        new_hp = self._coerce_int(raw_new_hp)
        if new_hp is None:
            # No HP information: do not clobber the cache with a fake 0.
            return

        snapshot = self._character_cache.get(character_id)
        cached_data = snapshot.model_dump() if snapshot is not None else None

        previous_hp = self._coerce_int(result.get("previous_hp"))
        if previous_hp is None:
            previous_hp = (
                self._coerce_int(cached_data.get("hp_current"), 0)
                if cached_data is not None
                else 0
            )
        max_hp = self._coerce_int(result.get("max_hp"))

        if cached_data is not None:
            cached_data["hp_current"] = new_hp
            if max_hp is not None:
                cached_data["hp_max"] = max_hp
            cached_data["cached_at"] = datetime.now()
            self._character_cache[character_id] = CharacterSnapshot.model_validate(
                cached_data
            )

        dropped_to_zero = new_hp <= 0 < previous_hp

        if dropped_to_zero and result.get("is_damage", True):
            self._event_logger.log_damage(
                character_name=str(result.get("name", "Character")),
                amount=previous_hp - new_hp,
                damage_type=str(result.get("damage_type", "untyped")),
                source="unknown",
                is_lethal=True,
            )

        # Keep the combat tracker's HP in sync for every change, not only
        # the lethal one.
        if self._combat_state:
            combatant = self._combat_state.get_combatant(character_id)
            if combatant:
                updates: dict[str, object] = {"hp_current": new_hp}
                if dropped_to_zero:
                    updates["status"] = CombatantStatus.UNCONSCIOUS
                elif new_hp > 0 and combatant.status == CombatantStatus.UNCONSCIOUS:
                    updates["status"] = CombatantStatus.ACTIVE
                self._combat_state.update_combatant(character_id, **updates)

    def _combatant_from_entry(self, entry: dict[str, object], position: int) -> Combatant:
        """
        Build an ACTIVE Combatant from one ``turn_order`` entry.

        Missing or null fields get defaults: a random UUID for ``id``,
        ``"Combatant <position>"`` for ``name``, 0 for initiative, and 10 for
        HP/AC. A string ``conditions`` value is treated as a single condition.
        """
        raw_id = entry.get("id")
        raw_name = entry.get("name")
        raw_conditions = entry.get("conditions")
        if isinstance(raw_conditions, str):
            conditions = [raw_conditions]
        elif isinstance(raw_conditions, (list, tuple, set)):
            conditions = list(raw_conditions)
        else:
            conditions = []

        return Combatant(
            combatant_id=str(raw_id) if raw_id is not None else str(uuid.uuid4()),
            name=str(raw_name) if raw_name else f"Combatant {position}",
            initiative=self._coerce_int(entry.get("initiative"), 0),
            is_player=bool(entry.get("is_player", False)),
            hp_current=self._coerce_int(entry.get("hp_current"), 10),
            hp_max=self._coerce_int(entry.get("hp_max"), 10),
            armor_class=self._coerce_int(entry.get("ac"), 10),
            conditions=conditions,
            status=CombatantStatus.ACTIVE,
        )

    async def _process_combat_start(self, result: dict[str, object]) -> None:
        """
        Process a combat start result.

        Builds the combatant list from ``turn_order`` (in the given order,
        skipping non-dict entries), creates a fresh CombatState at round 1,
        turn 0, and logs the combat start.
        """
        combatants: list[Combatant] = []
        turn_order = result.get("turn_order", [])
        if isinstance(turn_order, list):
            for position, entry in enumerate(turn_order, start=1):
                if isinstance(entry, dict):
                    combatants.append(self._combatant_from_entry(entry, position))

        raw_combat_id = result.get("combat_id")
        combat_state = CombatState(
            combat_id=str(raw_combat_id) if raw_combat_id is not None else str(uuid.uuid4()),
            round_number=1,
            turn_index=0,
            combatants=combatants,
            started_at=datetime.now(),
        )

        # Only flag combat as active once its state was built successfully.
        self._combat_state = combat_state
        self._in_combat = True

        self._event_logger.log_combat_start(
            description=str(result.get("description", "Combat began!")),
            combatants=[c.name for c in combatants],
        )

    async def _process_combat_end(self, result: dict[str, object]) -> None:
        """Process a combat end result: log the outcome and clear combat state."""
        self._event_logger.log_combat_end(
            outcome=str(result.get("outcome", "victory")),
            description=str(result.get("description", "")),
        )

        self._in_combat = False
        self._combat_state = None

    async def _process_next_turn(self, result: dict[str, object]) -> None:
        """
        Process a next turn result.

        Advances the local combat state; if that yields a combatant, integer
        ``turn_index``/``round_number`` values from the result override the
        locally advanced position.
        """
        if not self._combat_state:
            return

        if not self._combat_state.advance_turn():
            return

        turn_index = result.get("turn_index")
        if isinstance(turn_index, int):
            self._combat_state.turn_index = turn_index

        round_number = result.get("round_number")
        if isinstance(round_number, int):
            self._combat_state.round_number = round_number

    async def _process_condition_change(
        self,
        result: dict[str, object],
        added: bool,
    ) -> None:
        """
        Process a condition add/remove result.

        Applies the change to the cached snapshot and, during combat, to the
        matching combatant. Results without a character ID or condition are
        ignored.
        """
        character_id = str(result.get("character_id", ""))
        condition = str(result.get("condition", ""))
        if not character_id or not condition:
            return

        snapshot = self._character_cache.get(character_id)
        if snapshot is not None:
            updated_data = snapshot.model_dump()
            updated_data["conditions"] = self._toggle_condition(
                snapshot.conditions, condition, added
            )
            updated_data["cached_at"] = datetime.now()
            self._character_cache[character_id] = CharacterSnapshot.model_validate(
                updated_data
            )

        if self._combat_state:
            combatant = self._combat_state.get_combatant(character_id)
            if combatant:
                self._combat_state.update_combatant(
                    character_id,
                    conditions=self._toggle_condition(combatant.conditions, condition, added),
                )

    async def _process_rest(self, result: dict[str, object]) -> None:
        """Process a rest result: refresh the character from MCP and log the rest."""
        character_id = str(result.get("character_id", ""))
        if character_id:
            await self.refresh_character(character_id)

        self._event_logger.log_rest(
            rest_type=str(result.get("rest_type", "short")),
            character_name=str(result.get("name", "Character")),
            hp_recovered=self._coerce_int(result.get("hp_recovered"), 0),
        )

    # =========================================================================
    # Event Management
    # =========================================================================

    def add_event(
        self,
        event_type: EventType,
        description: str,
        data: dict[str, object] | None = None,
        is_significant: bool = False,
    ) -> SessionEvent:
        """
        Add a game event.

        Args:
            event_type: Type of event
            description: Human-readable description
            data: Event-specific data
            is_significant: Whether event is significant

        Returns:
            Created SessionEvent
        """
        # NOTE(review): relies on EventLogger's private _create_event; consider
        # exposing a public method on EventLogger.
        return self._event_logger._create_event(
            event_type=event_type,
            description=description,
            data=data,
            is_significant=is_significant,
        )

    # =========================================================================
    # Location Management
    # =========================================================================

    def set_location(
        self,
        location: str,
        scene: SceneInfo | None = None,
    ) -> None:
        """
        Update the current location and scene.

        A movement event is logged only when the location actually changes
        and the previous location was not the initial "Unknown".

        Args:
            location: Location name
            scene: Optional scene info
        """
        old_location = self._current_location
        self._current_location = location
        self._current_scene = scene

        if old_location != location and old_location != "Unknown":
            self._event_logger.log_movement(old_location, location)

    def add_known_npc(self, npc: NPCInfo) -> None:
        """
        Add or update a known NPC, keyed by ``npc.npc_id``.

        Args:
            npc: NPC info to add
        """
        self._known_npcs[npc.npc_id] = npc

    def get_npc(self, npc_id: str) -> NPCInfo | None:
        """
        Get a known NPC by ID.

        Args:
            npc_id: NPC ID

        Returns:
            NPCInfo if found, None otherwise
        """
        return self._known_npcs.get(npc_id)

    def get_npcs_in_scene(self) -> list[NPCInfo]:
        """
        Get known NPCs listed as present in the current scene.

        NPC IDs in the scene that are not known are skipped.

        Returns:
            List of NPCInfo objects
        """
        if not self._current_scene:
            return []
        return [
            self._known_npcs[npc_id]
            for npc_id in self._current_scene.npcs_present
            if npc_id in self._known_npcs
        ]

    def set_story_flag(self, flag: str, value: object) -> None:
        """
        Set a story/quest flag.

        Args:
            flag: Flag name
            value: Flag value
        """
        self._story_flags[flag] = value
        # Every flag change is passed to the event logger.
        self._event_logger.log_story_flag(flag, value)

    def get_story_flag(self, flag: str, default: object = None) -> object:
        """
        Get a story/quest flag.

        Args:
            flag: Flag name
            default: Default value

        Returns:
            Flag value or default
        """
        return self._story_flags.get(flag, default)

    # =========================================================================
    # Turn Management
    # =========================================================================

    def increment_turn(self) -> int:
        """
        Increment the turn counter and propagate it to the event logger.

        Returns:
            New turn count
        """
        self._turn_count += 1
        self._event_logger.set_current_turn(self._turn_count)
        return self._turn_count

    # =========================================================================
    # Save/Load
    # =========================================================================

    def _build_save_data(
        self,
        conversation_history: list[dict[str, object]] | None = None,
    ) -> GameSaveData:
        """Snapshot the current in-memory state into a GameSaveData model."""
        return GameSaveData(
            version=self._SAVE_VERSION,
            saved_at=datetime.now(),
            session_id=self._session_id,
            turn_count=self._turn_count,
            party_ids=self._party_ids.copy(),
            active_character_id=self._active_character_id,
            character_snapshots=list(self._character_cache.values()),
            current_location=self._current_location,
            current_scene=self._current_scene,
            in_combat=self._in_combat,
            combat_state=self._combat_state,
            story_flags=self._story_flags.copy(),
            known_npcs=self._known_npcs.copy(),
            recent_events=self._event_logger.events.copy(),
            adventure_name=self._adventure_name,
            conversation_history=conversation_history or [],
        )

    async def save(
        self,
        file_path: Path | str | None = None,
        conversation_history: list[dict[str, object]] | None = None,
    ) -> GameSaveData:
        """
        Save the current game state.

        Party characters are refreshed from MCP first. When ``file_path`` is
        given, the JSON is written to a sibling ``<name>.tmp`` file and then
        moved over the target. An interrupted write therefore never leaves
        a truncated save behind.

        Args:
            file_path: Optional path to save JSON file
            conversation_history: Optional chat history to include

        Returns:
            GameSaveData object
        """
        # Iterate over a copy: the party list may change while we await.
        for character_id in list(self._party_ids):
            await self.refresh_character(character_id)

        save_data = self._build_save_data(conversation_history)

        if file_path:
            path = Path(file_path)
            path.parent.mkdir(parents=True, exist_ok=True)
            tmp_path = path.with_name(path.name + ".tmp")
            try:
                tmp_path.write_text(save_data.model_dump_json(indent=2), encoding="utf-8")
                tmp_path.replace(path)
            except BaseException:
                tmp_path.unlink(missing_ok=True)
                raise
            logger.info("Game saved to: %s", path)

        return save_data

    async def load(self, file_path: Path | str) -> bool:
        """
        Load a saved game.

        Validation happens before any state changes. On success, any active
        MCP session is ended and all state is replaced from the save. The
        session clock restarts and a new MCP session is started when the
        client is connected.

        Args:
            file_path: Path to save file

        Returns:
            True if successful, False otherwise (the error is logged)
        """
        try:
            path = Path(file_path)
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
            save_data = GameSaveData.model_validate(data)

            # The file is valid: close the previous server-side session so
            # it is not orphaned and its ID does not leak into this game.
            await self._end_mcp_session()

            # Restore state
            self._session_id = save_data.session_id
            self._started_at = datetime.now()
            self._turn_count = save_data.turn_count
            self._party_ids = save_data.party_ids.copy()
            self._active_character_id = save_data.active_character_id
            self._current_location = save_data.current_location
            self._current_scene = save_data.current_scene
            self._in_combat = save_data.in_combat
            self._combat_state = save_data.combat_state
            self._story_flags = dict(save_data.story_flags)
            self._known_npcs = dict(save_data.known_npcs)
            self._adventure_name = save_data.adventure_name

            # Restore character cache (treated as freshly fetched)
            self._character_cache.clear()
            self._character_cache_times.clear()
            loaded_at = datetime.now()
            for snapshot in save_data.character_snapshots:
                self._character_cache[snapshot.character_id] = snapshot
                self._character_cache_times[snapshot.character_id] = loaded_at

            # Restore events.
            # NOTE(review): appends to EventLogger's private _events list,
            # which may bypass any max_events cap the logger enforces.
            self._event_logger.clear()
            for event in save_data.recent_events:
                self._event_logger._events.append(event)
            self._event_logger.set_current_turn(self._turn_count)

            # Start new MCP session for loaded game (best effort)
            try:
                await self._start_mcp_session(
                    f"Loaded: {self._adventure_name or 'Session'}"
                )
            except Exception as e:
                logger.warning("Failed to start MCP session for loaded game: %s", e)
                self._mcp_session_id = None

            self._event_logger.log_system(
                "Game loaded",
                {"loaded_from": str(path)},
            )

            logger.info("Game loaded from: %s", path)
            return True

        except Exception as e:
            logger.error("Failed to load game: %s", e, exc_info=True)
            return False

    def export_for_download(
        self,
        conversation_history: list[dict[str, object]] | None = None,
    ) -> str:
        """
        Export game state as a JSON string for browser download.

        Unlike ``save``, character snapshots are not refreshed from MCP.

        Args:
            conversation_history: Optional chat history to include

        Returns:
            JSON string
        """
        return self._build_save_data(conversation_history).model_dump_json(indent=2)

    # =========================================================================
    # Utilities
    # =========================================================================

    def set_toolkit_client(self, client: TTRPGToolkitClient | None) -> None:
        """
        Set or update the toolkit client (also forwarded to the event logger).

        Args:
            client: New toolkit client
        """
        self._toolkit_client = client
        self._event_logger.set_toolkit_client(client)

    def to_summary(self) -> dict[str, object]:
        """
        Create a summary dict for quick state overview.

        Returns:
            Summary dict
        """
        return {
            "session_id": self._session_id,
            "turn_count": self._turn_count,
            "party_size": len(self._party_ids),
            "active_character": self._active_character_id,
            "location": self._current_location,
            "in_combat": self._in_combat,
            "combat_round": self._combat_state.round_number if self._combat_state else None,
            "adventure": self._adventure_name,
            "event_count": len(self._event_logger),
        }

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"GameStateManager("
            f"session={self._session_id[:8]}..., "
            f"turn={self._turn_count}, "
            f"party={len(self._party_ids)}, "
            f"combat={self._in_combat})"
        )