# src/analyzer.py
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List

from transformers import pipeline

from .ontology import OntologyRegistry
from .relationships import RelationshipEngine


class EventAnalyzer:
    """Main analyzer class for event processing."""

    def __init__(self) -> None:
        """Initialize the event analyzer with required components."""
        self.ontology = OntologyRegistry()
        self.relationship_engine = RelationshipEngine()
        self.executor = ThreadPoolExecutor(max_workers=3)

        # Initialize NLP pipelines. aggregation_strategy="simple" merges
        # subword tokens into whole words, so each result carries an
        # "entity_group" label ("PER", "ORG", "LOC", ...) instead of
        # per-token "B-PER"/"I-PER" tags.
        self.ner_pipeline = pipeline(
            "ner",
            model="dbmdz/bert-large-cased-finetuned-conll03-english",
            aggregation_strategy="simple",
        )
        self.classifier = pipeline("zero-shot-classification")

    async def extract_entities(self, text: str) -> Dict[str, List[str]]:
        """Extract entities from text using the NER pipeline."""
        def _extract():
            return self.ner_pipeline(text)

        # Run the blocking pipeline call in the thread pool so it does not
        # stall the event loop.
        ner_results = await asyncio.get_running_loop().run_in_executor(
            self.executor, _extract
        )

        entities = {
            "people": [],
            "organizations": [],
            "locations": [],
            "hashtags": [word for word in text.split() if word.startswith("#")],
        }
        for item in ner_results:
            label = item["entity_group"]
            if label.endswith("PER"):
                entities["people"].append(item["word"])
            elif label.endswith("ORG"):
                entities["organizations"].append(item["word"])
            elif label.endswith("LOC"):
                entities["locations"].append(item["word"])
        return entities
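
    # Illustrative result for "Alice Smith spoke for ACME in Berlin #launch"
    # (hypothetical values; actual spans depend on the NER model):
    #   {"people": ["Alice Smith"], "organizations": ["ACME"],
    #    "locations": ["Berlin"], "hashtags": ["#launch"]}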

    def extract_temporal(self, text: str) -> List[str]:
        """Extract temporal expressions from text."""
        return self.ontology.validate_pattern(text, 'temporal')

    async def extract_locations(self, text: str) -> List[str]:
        """Extract locations using both NER and pattern matching."""
        entities = await self.extract_entities(text)
        ml_locations = entities.get('locations', [])
        pattern_locations = self.ontology.validate_pattern(text, 'location')
        # Deduplicate across the two sources.
        return list(set(ml_locations + pattern_locations))

    def calculate_confidence(self,
                             entities: Dict[str, List[str]],
                             temporal_data: List[str],
                             related_events: List[Any]) -> float:
        """Calculate a confidence score for the extracted information."""
        # Base confidence from entity presence.
        base_confidence = min(1.0, (
            0.2 * bool(entities["people"]) +
            0.2 * bool(entities["organizations"]) +
            0.3 * bool(entities["locations"]) +
            0.3 * bool(temporal_data)
        ))

        # Collect entity strings for the frequency lookup.
        entity_params = [
            *entities["people"],
            *entities["organizations"],
            *entities["locations"]
        ]
        if not entity_params:
            return base_confidence

        # Boost confidence for entities that have been seen frequently before.
        query = f'''
            SELECT AVG(frequency) as avg_freq
            FROM entities
            WHERE entity_text IN ({','.join(['?'] * len(entity_params))})
        '''
        cursor = self.relationship_engine.conn.execute(query, entity_params)
        avg_frequency = cursor.fetchone()[0] or 1
        frequency_boost = min(0.2, (avg_frequency - 1) * 0.05)

        # Boost confidence by the strongest overlap with a known event, using
        # the shared-entity count that find_related_events returns in
        # column 4 of each row.
        relationship_confidence = 0.0
        if related_events:
            relationship_scores = [
                min(0.3, (event[4] if len(event) > 4 else 0) * 0.1)
                for event in related_events
            ]
            relationship_confidence = max(relationship_scores)

        return min(1.0, base_confidence + frequency_boost + relationship_confidence)
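
    # Illustrative arithmetic (hypothetical values): people plus temporal
    # data gives a base of 0.2 + 0.3 = 0.5; an average stored frequency of 3
    # adds min(0.2, (3 - 1) * 0.05) = 0.1; a best related event sharing two
    # entities adds min(0.3, 2 * 0.1) = 0.2; final score
    # min(1.0, 0.5 + 0.1 + 0.2) = 0.8.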

    async def analyze_event(self, text: str) -> Dict[str, Any]:
        """Analyze event text and extract structured information."""
        try:
            # Parallel extraction: temporal matching is synchronous, the two
            # pipeline-backed extractions run concurrently.
            entities_future = self.extract_entities(text)
            temporal_data = self.extract_temporal(text)
            locations_future = self.extract_locations(text)

            entities, locations = await asyncio.gather(
                entities_future, locations_future
            )

            # Merge locations and add temporal data.
            entities['locations'] = locations
            entities['temporal'] = temporal_data

            # Find related events.
            related_events = self.relationship_engine.find_related_events({
                'text': text,
                'entities': entities
            })

            # Calculate confidence.
            confidence = self.calculate_confidence(entities, temporal_data, related_events)

            # Store the event only if confidence meets the threshold.
            event_id = None
            if confidence >= 0.6:
                cursor = self.relationship_engine.conn.execute(
                    'INSERT INTO events (text, timestamp, confidence) VALUES (?, ?, ?)',
                    (text, datetime.now().isoformat(), confidence)
                )
                event_id = cursor.lastrowid

                # Store entities and update relationships.
                self.relationship_engine.store_entities(event_id, {
                    'person': entities['people'],
                    'organization': entities['organizations'],
                    'location': entities['locations'],
                    'temporal': temporal_data,
                    'hashtag': entities['hashtags']
                })
                self.relationship_engine.update_entity_relationships(event_id)
                self.relationship_engine.conn.commit()

            # Get entity relationships for the output.
            entity_relationships = []
            if event_id:
                entity_relationships = self.relationship_engine.get_entity_relationships(event_id)

            return {
                "text": text,
                "entities": entities,
                "confidence": confidence,
                "verification_needed": confidence < 0.6,
                "related_events": [
                    {
                        "text": event[1],
                        "timestamp": event[2],
                        "confidence": event[3],
                        "shared_entities": event[4] if len(event) > 4 else None
                    }
                    for event in related_events
                ],
                "entity_relationships": entity_relationships
            }
        except Exception as e:
            return {"error": str(e)}

    def get_entity_statistics(self) -> Dict[str, List[tuple]]:
        """Get statistics about stored entities and relationships."""
        stats = {}

        # Entity counts and average frequency by type.
        cursor = self.relationship_engine.conn.execute('''
            SELECT entity_type, COUNT(*) as count, AVG(frequency) as avg_frequency
            FROM entities
            GROUP BY entity_type
        ''')
        stats['entity_counts'] = cursor.fetchall()

        # Ten most frequent entities.
        cursor = self.relationship_engine.conn.execute('''
            SELECT entity_text, entity_type, frequency
            FROM entities
            ORDER BY frequency DESC
            LIMIT 10
        ''')
        stats['frequent_entities'] = cursor.fetchall()

        # Relationship counts and average confidence by type.
        cursor = self.relationship_engine.conn.execute('''
            SELECT relationship_type, COUNT(*) as count, AVG(confidence) as avg_confidence
            FROM entity_relationships
            GROUP BY relationship_type
        ''')
        stats['relationship_stats'] = cursor.fetchall()

        return stats
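

# Usage sketch (hypothetical): assumes the package context for the relative
# imports above and an initialized RelationshipEngine database; run with
# `python -m src.analyzer` from the project root.
if __name__ == "__main__":
    async def _demo() -> None:
        analyzer = EventAnalyzer()
        result = await analyzer.analyze_event(
            "Protest at City Hall on March 5th #rally"
        )
        print(result.get("confidence"), result.get("verification_needed"))

    asyncio.run(_demo())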