farmax commited on
Commit
6f3f570
·
verified ·
1 Parent(s): 629f6d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -98
app.py CHANGED
@@ -5,7 +5,6 @@ from typing import Dict, Any, List, Tuple
5
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM, pipeline
6
 
7
  # ================= CONFIGURAZIONE =================
8
-
9
  MODEL_DEBERTA = "osiria/deberta-italian-question-answering"
10
  MODEL_GEPPETTO = "LorenzoDeMattei/GePpeTto"
11
 
@@ -45,50 +44,25 @@ DOMANDE_IT: List[Tuple[str, str]] = [
45
  ("causale", "Qual è la causale della fattura? / Qual è la motivazione o descrizione del pagamento?")
46
  ]
47
 
48
- # ================= CACHE MODELLI =================
49
- LOADED: Dict[str, Any] = {}
50
-
51
- def get_deberta_pipeline():
52
- if "deb" in LOADED: return LOADED["deb"]
53
- tok = AutoTokenizer.from_pretrained(MODEL_DEBERTA)
54
- mdl = AutoModelForQuestionAnswering.from_pretrained(MODEL_DEBERTA)
55
- qa = pipeline("question-answering", model=mdl, tokenizer=tok, handle_impossible_answer=True, top_k=1, device=-1)
56
- LOADED["deb"] = qa
57
- return qa
58
-
59
- def get_geppetto_pipeline():
60
- if "gepp" in LOADED: return LOADED["gepp"]
61
- tok = AutoTokenizer.from_pretrained(MODEL_GEPPETTO)
62
- mdl = AutoModelForCausalLM.from_pretrained(MODEL_GEPPETTO)
63
- gen = pipeline("text-generation", model=mdl, tokenizer=tok, device=-1)
64
- LOADED["gepp"] = gen
65
- return gen
66
 
67
- # ================= UTILITY =================
 
 
68
 
 
69
  def preprocess_markdown(text: str) -> str:
70
  if not text: return ""
71
- text = re.sub(r'\|[\s-]+\|', ' ', text) # ripulisce separatori tabella
72
  text = text.replace('|', ' ')
73
  text = text.replace('**', '').replace('##', '')
74
- # mapping semantico leggero
75
- text = text.replace('P.IVA', 'partita IVA').replace('PIVA', 'partita IVA')
76
  text = re.sub(r'\s+', ' ', text).strip()
77
  return text
78
 
79
- def chunk_text(text: str, max_chars: int = 3000, overlap: int = 200) -> List[str]:
80
- if len(text) <= max_chars: return [text]
81
- chunks = []
82
- i = 0
83
- while i < len(text):
84
- end = min(i + max_chars, len(text))
85
- chunks.append(text[i:end])
86
- i = end - overlap
87
- if i < 0: i = 0
88
- return chunks
89
-
90
- # ================= LOGICA PRINCIPALE =================
91
-
92
  def analyze_invoice(md_text: str, custom_question_it: str):
93
  logs: List[str] = []
94
  final_output: Dict[str, Any] = {}
@@ -97,86 +71,72 @@ def analyze_invoice(md_text: str, custom_question_it: str):
97
  return {"Error": "Testo troppo breve"}, "⚠️ Inserisci almeno 10 caratteri."
98
 
99
  clean_text = preprocess_markdown(md_text)
100
- chunks = chunk_text(clean_text, max_chars=3000, overlap=200)
101
- logs.append(f"📄 Testo originale: {len(md_text)} chars | Pulito: {len(clean_text)} chars | Chunks: {len(chunks)}")
102
-
103
- qa_deb = get_deberta_pipeline()
104
- gen_gepp = get_geppetto_pipeline()
105
 
106
  # 1) DeBERTa: QA estrattivo su tutte le domande + opzionale
107
  t_start_deb = time.time()
108
  deb_res: Dict[str, Any] = {}
109
  success_count = 0
110
 
111
- def ask_all_chunks(question: str) -> Tuple[str, float]:
112
- best_answer, best_score = "", 0.0
113
- for c in chunks:
114
- try:
115
- r = qa_deb(question=question, context=c)
116
- ans = r.get("answer", "").strip()
117
- score = float(r.get("score", 0.0))
118
- if score > best_score and ans:
119
- best_answer, best_score = ans, score
120
- except Exception as e:
121
- logs.append(f"❌ Errore QA chunk: {str(e)}")
122
- return best_answer, best_score
123
-
124
  for key, question_text in DOMANDE_IT:
125
- answer, score = ask_all_chunks(question_text)
126
- status = "Successo" if score > 0.05 and answer else "Non Trovato"
127
- if status == "Successo": success_count += 1
128
- deb_res[key] = {
129
- "domanda": question_text,
130
- "risposta": answer,
131
- "confidenza": round(score, 3),
132
- "status": status
133
- }
 
 
 
 
 
134
 
135
  custom_q = custom_question_it.strip()
136
  if custom_q:
137
- answer, score = ask_all_chunks(custom_q)
138
- status = "Successo" if score > 0.05 and answer else "Non Trovato"
139
- if status == "Successo": success_count += 1
140
- deb_res["domanda_opzionale"] = {
141
- "domanda": custom_q,
142
- "risposta": answer,
143
- "confidenza": round(score, 3),
144
- "status": status
145
- }
 
 
 
 
 
146
 
147
  t_elapsed_deb = round(time.time() - t_start_deb, 2)
148
  final_output["DeBERTa (estrattivo)"] = deb_res
149
  logs.append(f"✅ DeBERTa completato in {t_elapsed_deb}s | Successi: {success_count}/{len(DOMANDE_IT) + (1 if custom_q else 0)}")
150
 
151
- # 2) GePpeTto: generativo su tutte le domande in blocco
152
- t_start_gepp = time.time()
153
- try:
154
- # Costruzione prompt conciso per ridurre rumore
155
- prompt_lines = ["Rispondi in elenco puntato alle seguenti domande sulla fattura:"]
156
- for _, q in DOMANDE_IT:
157
- prompt_lines.append(f"- {q}")
158
- if custom_q:
159
- prompt_lines.append(f"- {custom_q}")
160
- prompt_lines.append("\nContesto:")
161
- prompt_lines.append(clean_text[:4000]) # taglio prudenziale su CPU
162
- prompt_lines.append("\nRisposte (usa un punto per ogni domanda, senza inventare dati):")
163
-
164
- prompt = "\n".join(prompt_lines)
165
- gen = gen_gepp(prompt, max_new_tokens=256, do_sample=False)
166
- generative_text = gen[0]["generated_text"].replace(prompt, "").strip()
167
- final_output["GePpeTto (generativo)"] = {"risposte": generative_text}
168
- t_elapsed_gepp = round(time.time() - t_start_gepp, 2)
169
- logs.append(f"✅ GePpeTto completato in {t_elapsed_gepp}s")
170
- except Exception as e:
171
- final_output["GePpeTto (generativo)"] = {"errore": str(e)}
172
- logs.append(f"❌ Errore GePpeTto: {e}")
173
 
174
  return final_output, "\n".join(logs)
175
 
176
  # ================== UI GRADIO ==================
177
  with gr.Blocks(theme=gr.themes.Base()) as demo:
178
- gr.Markdown("# 🧾 Invoice QA: Domande standard + opzionale (DeBERTa estrattivo & GePpeTto generativo)")
179
- gr.Markdown("Risposte estrattive strutturate per tutte le domande e un blocco generativo riassuntivo, con log e tempi.")
180
 
181
  with gr.Row():
182
  with gr.Column(scale=1):
@@ -194,11 +154,11 @@ with gr.Blocks(theme=gr.themes.Base()) as demo:
194
  btn = gr.Button("🔍 Analizza documento", variant="primary")
195
 
196
  with gr.Column(scale=1):
197
- out_json = gr.JSON(label="Risultati estrattivi (DeBERTa) e generativi (GePpeTto)")
198
  with gr.Accordion("📝 Log di Sistema (Tempi e Debug)", open=False):
199
  out_log = gr.Textbox(label="Process Log", lines=12)
200
 
201
  btn.click(fn=analyze_invoice, inputs=[md_input, custom_q_input], outputs=[out_json, out_log])
202
 
203
  if __name__ == "__main__":
204
- demo.launch()
 
5
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM, pipeline
6
 
7
  # ================= CONFIGURAZIONE =================
 
8
  MODEL_DEBERTA = "osiria/deberta-italian-question-answering"
9
  MODEL_GEPPETTO = "LorenzoDeMattei/GePpeTto"
10
 
 
44
  ("causale", "Qual è la causale della fattura? / Qual è la motivazione o descrizione del pagamento?")
45
  ]
46
 
47
+ # ================= PIPELINES =================
48
+ tok_deb = AutoTokenizer.from_pretrained(MODEL_DEBERTA)
49
+ mdl_deb = AutoModelForQuestionAnswering.from_pretrained(MODEL_DEBERTA)
50
+ qa_deb = pipeline("question-answering", model=mdl_deb, tokenizer=tok_deb, device=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ tok_gepp = AutoTokenizer.from_pretrained(MODEL_GEPPETTO)
53
+ mdl_gepp = AutoModelForCausalLM.from_pretrained(MODEL_GEPPETTO)
54
+ qa_gepp = pipeline("text-generation", model=mdl_gepp, tokenizer=tok_gepp, device=-1)
55
 
56
+ # ================= UTILITY =================
57
  def preprocess_markdown(text: str) -> str:
58
  if not text: return ""
59
+ text = re.sub(r'\|[\s-]+\|', ' ', text)
60
  text = text.replace('|', ' ')
61
  text = text.replace('**', '').replace('##', '')
 
 
62
  text = re.sub(r'\s+', ' ', text).strip()
63
  return text
64
 
65
+ # ================= FUNZIONE PRINCIPALE =================
 
 
 
 
 
 
 
 
 
 
 
 
66
  def analyze_invoice(md_text: str, custom_question_it: str):
67
  logs: List[str] = []
68
  final_output: Dict[str, Any] = {}
 
71
  return {"Error": "Testo troppo breve"}, "⚠️ Inserisci almeno 10 caratteri."
72
 
73
  clean_text = preprocess_markdown(md_text)
 
 
 
 
 
74
 
75
  # 1) DeBERTa: QA estrattivo su tutte le domande + opzionale
76
  t_start_deb = time.time()
77
  deb_res: Dict[str, Any] = {}
78
  success_count = 0
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  for key, question_text in DOMANDE_IT:
81
+ try:
82
+ res = qa_deb(question=question_text, context=clean_text)
83
+ answer = res["answer"].strip()
84
+ score = round(res.get("score", 0.0), 3)
85
+ status = "Successo" if score > 0.05 and answer else "Non Trovato"
86
+ if status == "Successo": success_count += 1
87
+ deb_res[key] = {
88
+ "domanda": question_text,
89
+ "risposta": answer,
90
+ "confidenza": score,
91
+ "status": status
92
+ }
93
+ except Exception as e:
94
+ deb_res[key] = {"status": f"Errore inferenza: {str(e)}"}
95
 
96
  custom_q = custom_question_it.strip()
97
  if custom_q:
98
+ try:
99
+ res = qa_deb(question=custom_q, context=clean_text)
100
+ answer = res["answer"].strip()
101
+ score = round(res.get("score", 0.0), 3)
102
+ status = "Successo" if score > 0.05 and answer else "Non Trovato"
103
+ if status == "Successo": success_count += 1
104
+ deb_res["domanda_opzionale"] = {
105
+ "domanda": custom_q,
106
+ "risposta": answer,
107
+ "confidenza": score,
108
+ "status": status
109
+ }
110
+ except Exception as e:
111
+ deb_res["domanda_opzionale"] = {"status": f"Errore inferenza: {str(e)}"}
112
 
113
  t_elapsed_deb = round(time.time() - t_start_deb, 2)
114
  final_output["DeBERTa (estrattivo)"] = deb_res
115
  logs.append(f"✅ DeBERTa completato in {t_elapsed_deb}s | Successi: {success_count}/{len(DOMANDE_IT) + (1 if custom_q else 0)}")
116
 
117
+ # 2) GePpeTto: SOLO domanda opzionale
118
+ if custom_q:
119
+ try:
120
+ t_start_gepp = time.time()
121
+ short_context = clean_text[:800] # taglio prudenziale
122
+ prompt = f"Domanda: {custom_q}\nContesto: {short_context}\nRisposta:"
123
+ res_gepp = qa_gepp(prompt, max_new_tokens=64, do_sample=False)
124
+ generative_text = res_gepp[0]["generated_text"].replace(prompt, "").strip()
125
+ final_output["GePpeTto (generativo)"] = {"risposta_opzionale": generative_text}
126
+ t_elapsed_gepp = round(time.time() - t_start_gepp, 2)
127
+ logs.append(f"✅ GePpeTto completato in {t_elapsed_gepp}s")
128
+ except Exception as e:
129
+ final_output["GePpeTto (generativo)"] = {"errore": str(e)}
130
+ logs.append(f"❌ Errore GePpeTto: {e}")
131
+ else:
132
+ final_output["GePpeTto (generativo)"] = {"info": "Nessuna domanda opzionale fornita"}
 
 
 
 
 
 
133
 
134
  return final_output, "\n".join(logs)
135
 
136
  # ================== UI GRADIO ==================
137
  with gr.Blocks(theme=gr.themes.Base()) as demo:
138
+ gr.Markdown("# 🧾 Invoice QA: Domande standard + opzionale")
139
+ gr.Markdown("Risposte estrattive (DeBERTa) su tutte le domande e generative (GePpeTto) solo sulla domanda opzionale.")
140
 
141
  with gr.Row():
142
  with gr.Column(scale=1):
 
154
  btn = gr.Button("🔍 Analizza documento", variant="primary")
155
 
156
  with gr.Column(scale=1):
157
+ out_json = gr.JSON(label="Risultati (DeBERTa estrattivo + GePpeTto opzionale)")
158
  with gr.Accordion("📝 Log di Sistema (Tempi e Debug)", open=False):
159
  out_log = gr.Textbox(label="Process Log", lines=12)
160
 
161
  btn.click(fn=analyze_invoice, inputs=[md_input, custom_q_input], outputs=[out_json, out_log])
162
 
163
  if __name__ == "__main__":
164
+ demo.launch()