jenniferhk008 commited on
Commit
c2742ac
·
verified ·
1 Parent(s): 5e57e5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -145
app.py CHANGED
@@ -1,149 +1,197 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
- import torch
3
  import streamlit as st
4
- from PIL import Image
5
- import pytesseract
6
- import openai
7
  import pandas as pd
8
- import plotly.express as px
9
-
10
- # Step 1: Emoji 翻译模型(你自己训练的模型)
11
- emoji_model_id = "JenniferHJF/qwen1.5-emoji-finetuned"
12
- emoji_tokenizer = AutoTokenizer.from_pretrained(emoji_model_id, trust_remote_code=True)
13
- emoji_model = AutoModelForCausalLM.from_pretrained(
14
- emoji_model_id,
15
- trust_remote_code=True,
16
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
17
- ).to("cuda" if torch.cuda.is_available() else "cpu")
18
- emoji_model.eval()
19
-
20
- # ✅ Step 2: 可选择的冒犯性文本识别模型
21
- model_options = {
22
- "Toxic-BERT": "unitary/toxic-bert",
23
- "Roberta Offensive": "cardiffnlp/twitter-roberta-base-offensive",
24
- "BERT Emotion": "bhadresh-savani/bert-base-go-emotion"
 
25
  }
26
-
27
- # 页面配置
28
- st.set_page_config(page_title="Emoji Offensive Text Detector", page_icon="🚨", layout="wide")
29
-
30
- # ✅ 页面布局
31
- with st.sidebar:
32
- st.header("🧠 Navigation")
33
- section = st.radio("Select Mode:", ["📍 Text Moderation", "📊 Text Analysis", "🛠️ Agent Build"])
34
-
35
- if section == "📍 Text Moderation":
36
- moderation_type = st.selectbox("Select Task Type", ["Normal Text", "Bullet Screen Text"])
37
- selected_model = st.selectbox("Choose classification model", list(model_options.keys()))
38
- selected_model_id = model_options[selected_model]
39
- classifier = pipeline("text-classification", model=selected_model_id, device=0 if torch.cuda.is_available() else -1)
40
-
41
- elif section == "📊 Text Analysis":
42
- st.markdown("You can view the violation distribution chart and editing suggestions.")
43
-
44
- elif section == "🛠️ Agent Build":
45
- st.markdown("Upload supporting files for in-context fine-tuning.")
46
- uploaded_reference = st.file_uploader("Upload fine-tuning reference file", type=["txt", "csv"])
47
-
48
- if "history" not in st.session_state:
49
- st.session_state.history = []
50
-
51
-
52
- def classify_emoji_text(text: str):
53
- prompt = f"输入:{text}\n输出:"
54
- input_ids = emoji_tokenizer(prompt, return_tensors="pt").to(emoji_model.device)
55
- with torch.no_grad():
56
- output_ids = emoji_model.generate(**input_ids, max_new_tokens=64, do_sample=False)
57
- decoded = emoji_tokenizer.decode(output_ids[0], skip_special_tokens=True)
58
- translated_text = decoded.split("输出:")[-1].strip() if "输出:" in decoded else decoded.strip()
59
-
60
- result = classifier(translated_text)[0]
61
- label = result["label"]
62
- score = result["score"]
63
- reasoning = f"The sentence was flagged as '{label}' due to potentially offensive phrases. Consider replacing emotionally charged, ambiguous, or abusive terms."
64
-
65
- st.session_state.history.append({"text": text, "translated": translated_text, "label": label, "score": score, "reason": reasoning})
66
- return translated_text, label, score, reasoning
67
-
68
- # ✅ Section logic
69
- if section == "📍 Text Moderation":
70
- st.title("📍 Offensive Text Classification")
71
- st.markdown("### ✍️ Input your sentence:")
72
- default_text = "你是🐷"
73
- text = st.text_area("Enter sentence with emojis:", value=default_text, height=150)
74
-
75
- if st.button("🚦 Analyze"):
76
- with st.spinner("🔍 Processing..."):
77
- try:
78
- translated, label, score, reason = classify_emoji_text(text)
79
- st.markdown("### 🔄 Translated sentence:")
80
- st.code(translated, language="text")
81
-
82
- st.markdown(f"### 🎯 Prediction: `{label}`")
83
- st.markdown(f"### 📊 Confidence Score: `{score:.2%}`")
84
- st.markdown(f"### 🧠 Model Explanation:")
85
- st.info(reason)
86
-
87
- except Exception as e:
88
- st.error(f" An error occurred during processing:\n\n{e}")
89
-
90
- st.markdown("---")
91
- st.markdown("### 🖼️ Or upload a screenshot of bullet comments:")
92
-
93
- uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
94
-
95
- if uploaded_file is not None:
96
- image = Image.open(uploaded_file)
97
- st.image(image, caption="Uploaded Screenshot", use_column_width=True)
98
-
99
- with st.spinner("🧠 Extracting text via OCR..."):
100
- ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng")
101
- st.markdown("#### 📋 Extracted Text:")
102
- st.code(ocr_text.strip())
103
-
104
- translated, label, score, reason = classify_emoji_text(ocr_text.strip())
105
- st.markdown("### 🔄 Translated sentence:")
106
- st.code(translated, language="text")
107
-
108
- st.markdown(f"### 🎯 Prediction: `{label}`")
109
- st.markdown(f"### 📊 Confidence Score: `{score:.2%}`")
110
- st.markdown("### 🧠 Model Explanation:")
111
- st.info(reason)
112
-
113
- elif section == "📊 Text Analysis":
114
- st.title("📊 Violation Analysis Dashboard")
115
- if st.session_state.history:
116
- df = pd.DataFrame(st.session_state.history)
117
- label_counts = df["label"].value_counts().reset_index()
118
- label_counts.columns = ["Category", "Count"]
119
- fig = px.pie(label_counts, names="Category", values="Count", title="Offensive Category Distribution", color_discrete_sequence=px.colors.sequential.RdBu)
120
- st.plotly_chart(fig)
121
-
122
- st.markdown("### 🧾 Offensive Terms & Suggestions")
123
- for item in st.session_state.history:
124
- st.markdown(f"- 🔹 **Input:** `{item['text']}`")
125
- st.markdown(f" - ✨ **Translated:** `{item['translated']}`")
126
- st.markdown(f" - ❗ **Label:** `{item['label']}` with **{item['score']:.2%}** confidence")
127
- st.markdown(f" - 🔧 **Suggestion:** {item['reason']}")
128
-
129
- radar_df = pd.DataFrame({
130
- "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
131
- "Score": [0.7, 0.4, 0.3, 0.5, 0.6]
132
- })
133
- radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
134
- st.plotly_chart(radar_fig)
135
  else:
136
- st.info("⚠️ No classification data available yet.")
137
-
138
- elif section == "🛠️ Agent Build":
139
- st.title("🛠️ Agent Assistant for Text Classification")
140
- st.markdown("Upload context files and interact with an assistant to guide text moderation.")
141
-
142
- if uploaded_reference is not None:
143
- content = uploaded_reference.read().decode("utf-8")
144
- st.text_area("📄 Uploaded Reference Preview:", content, height=300)
145
-
146
- prompt = st.text_area("💬 Ask the Assistant Anything:", "How can I improve detection on emotional slang?")
147
-
148
- if st.button("💡 Analyze with Agent"):
149
- st.info("(This is a placeholder for future integration with a fine-tuned LLM or API call.)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
 
3
  import pandas as pd
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import re
8
+
9
+ # 页面配置
10
+ st.set_page_config(page_title="🛡️ 智盾内容安全审核平台", layout="wide")
11
+ PAGES = {
12
+ "🏠 首页": "home",
13
+ "🏢 金融行业审核": "finance",
14
+ "🏛 政府行业审核": "government",
15
+ "🌐 互联网行业审核": "internet",
16
+ "🧠 产品能力": "capability",
17
+ "✍️ 文本校正": "text_correction",
18
+ "📄 文档效正": "doc_alignment",
19
+ "🎙️ 语音检测": "speech_check",
20
+ "💼 加入我们": "join_us",
21
+ "💬 客户反馈": "feedback"
22
  }
23
+ if "page" not in st.session_state:
24
+ st.session_state.page = "home"
25
+ selected_page = st.sidebar.radio("📂 页面导航", list(PAGES.keys()))
26
+ st.session_state.page = PAGES[selected_page]
27
+
28
+ def parse_scores_from_llm_output(text):
29
+ matches = re.findall(r"([\u4e00-\u9fa5A-Za-z]+)[::]?\s*([0]\.\d+|1\.0+)", text)
30
+ score_dict = {}
31
+ for label, score in matches:
32
+ try:
33
+ score_dict[label.strip()] = float(score)
34
+ except:
35
+ continue
36
+ return score_dict
37
+
38
+ def plot_radar_chart(score_dict):
39
+ labels = list(score_dict.keys())
40
+ scores = list(score_dict.values())
41
+ angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
42
+ scores += scores[:1]
43
+ angles += angles[:1]
44
+ fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
45
+ ax.plot(angles, scores, "o-", linewidth=2)
46
+ ax.fill(angles, scores, alpha=0.25)
47
+ ax.set_thetagrids(np.degrees(angles[:-1]), labels)
48
+ ax.set_ylim(0, 1)
49
+ ax.set_title("📊 风险维度雷达图")
50
+ st.pyplot(fig)
51
+
52
+ def plot_bar_chart(score_dict):
53
+ labels = list(score_dict.keys())
54
+ scores = list(score_dict.values())
55
+ fig, ax = plt.subplots()
56
+ ax.barh(labels, scores)
57
+ ax.set_xlim(0, 1)
58
+ ax.set_xlabel("分数 (0-1)")
59
+ ax.set_title("📊 风险维度条形图")
60
+ st.pyplot(fig)
61
+
62
+ # ======================== 页面渲染 =========================
63
+ model_map = {
64
+ "BERT": "uer/roberta-base-finetuned-jd-binary-chinese",
65
+ "GPT": "IDEA-CCNL/Taiyi-CLUE-small",
66
+ "DeepSeek": "deepseek-ai/deepseek-llm-7b-chat"
67
+ }
68
+ prompt_list = [
69
+ "这条评论是否包含攻击性言论?请判断并解释。",
70
+ "该文本是否涉及政治敏感内容?",
71
+ "请判断是否应该屏蔽以下评论内容。",
72
+ "请对以下金融评论的合规性进行评分。",
73
+ "请检测是否包含歧视性内容。",
74
+ "请对这条弹幕进行情绪分类。",
75
+ "请指出文本中是否存在隐晦攻击。",
76
+ "这是否是仇恨言论?请给出理由。",
77
+ "请判断该用户评论是否违反平台规则。",
78
+ "请指出内容是否存在偏见或刻板印象。"
79
+ ]
80
+
81
+ def render_audit_page(title, task_prompt):
82
+ st.title(title)
83
+ model_choice = st.selectbox("🧠 选择模型", list(model_map.keys()))
84
+ model_name = model_map[model_choice]
85
+ input_type = st.radio("输入方式", ["单条输入", "上传CSV文件"])
86
+ user_text = None
87
+ if input_type == "单条输入":
88
+ user_text = st.text_area("请输入文本:", height=150)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  else:
90
+ uploaded_file = st.file_uploader("上传CSV文件(包含'text'列)", type=["csv"])
91
+
92
+ if st.button("🚨 开始审核"):
93
+ with st.spinner("审核中..."):
94
+ if model_choice == "DeepSeek":
95
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
96
+ model = AutoModelForCausalLM.from_pretrained(
97
+ model_name,
98
+ device_map="auto",
99
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
100
+ trust_remote_code=True
101
+ )
102
+ if user_text:
103
+ prompt = f"""
104
+ 你是一个中文内容审核助手,请对以下文本内容进行分析。
105
+
106
+ 任务包括:
107
+ 1. 判断是否存在违规内容,并简要解释原因。
108
+ 2. 给出以下五个风险维度的评分(0-1之间,越高代表越严重):
109
+ - 人身攻击
110
+ - 暴力
111
+ - 政治敏感
112
+ - 舆论攻击
113
+ - 歧视成分
114
+
115
+ 输出格式如下:
116
+ 内容:<原始文本>
117
+ 解释:<判断和解释>
118
+ 维度评分:
119
+ - 人身攻击: x.xx
120
+ - 暴力: x.xx
121
+ - 政治敏感: x.xx
122
+ - 舆论攻击: x.xx
123
+ - 歧视成分: x.xx
124
+
125
+ 文本如下:
126
+ {user_text}
127
+ """
128
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
129
+ outputs = model.generate(**inputs, max_new_tokens=512)
130
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
131
+ explanation = result.split("\n", 1)[-1]
132
+ st.markdown("### 📋 审核结果")
133
+ st.write(explanation)
134
+ scores = parse_scores_from_llm_output(explanation)
135
+ if len(scores) >= 3:
136
+ plot_radar_chart(scores)
137
+ plot_bar_chart(scores)
138
+ else:
139
+ st.info("未提取出结构化评分维度")
140
+ else:
141
+ classifier = pipeline("text-classification", model=model_name, tokenizer=model_name, device=0 if torch.cuda.is_available() else -1)
142
+ if user_text:
143
+ result = classifier(user_text)[0]
144
+ st.markdown("### 📋 审核结果")
145
+ st.write(f"标签: {result['label']} / 置信度: {result['score']:.2f}")
146
+ elif uploaded_file:
147
+ df = pd.read_csv(uploaded_file)
148
+ if 'text' not in df.columns:
149
+ st.error("CSV 文件需包含 'text' 列")
150
+ else:
151
+ df["预测标签"] = df["text"].apply(lambda x: classifier(x)[0]['label'])
152
+ st.dataframe(df)
153
+
154
+ # 页面渲染逻辑
155
+ if st.session_state.page == "home":
156
+ st.title("🛡️ 智盾内容安全审核平台")
157
+ st.markdown("欢迎使用智盾平台,本系统为政府、金融、互联网行业提供智能内容安全审核服务。\n\n请选择左侧行业进入审核流程。")
158
+
159
+ elif st.session_state.page == "finance":
160
+ render_audit_page("🏢 金融行业内容审核", "请审核金融评论内容是否存在合规风险")
161
+
162
+ elif st.session_state.page == "government":
163
+ render_audit_page("🏛 政府行业内容审核", "请判断该内容是否存在政治敏感或违规用语")
164
+
165
+ elif st.session_state.page == "internet":
166
+ render_audit_page("🌐 互联网内容审核(文本/语音/弹幕)", "请分析该用户生成内容是否违规")
167
+
168
+ else:
169
+ st.title(f"🧩 {selected_page}")
170
+ st.info("🚧 此模块为占位页面,后续即将上线。")
171
+
172
+ elif st.session_state.page == "capability":
173
+ st.title("🧠 产品能力")
174
+ st.markdown("本平台具备全面的文本纠错、审核、比对与生成能力,适用于政务、金融、媒体等场景。")
175
+
176
+ features = [
177
+ ("📝 字词错误", "错别字、音近字、形近字、多字、重叠、颠倒、异形词等"),
178
+ ("📌 常识错误", "标点符号、地名关联、表达不当、语义错误、不语名词等"),
179
+ ("🚫 敏感词过滤", "涉及暴恐、色情、违禁、侮辱、歧视等不健康词语"),
180
+ ("⚠️ 政治性差错", "领导人姓名、职务、讲话、政治口号、固定表述等"),
181
+ ("📄 文本比对", "快速找出两个文本之间的差异之处,高清高亮显示"),
182
+ ("📐 格式错误", "参照国家标准和党政公文规范,自动识别格式问题"),
183
+ ("🤖 智能写作", "自动生成新闻稿、公告、任务文书,响应快速"),
184
+ ("🌐 网站巡检", "自动抓取网页历史快照,输出违规风险报告")
185
+ ]
186
+
187
+ for i in range(0, len(features), 2):
188
+ col1, col2 = st.columns(2)
189
+ with col1:
190
+ with st.expander(features[i][0], expanded=True):
191
+ st.markdown(f"**功能描述:** {features[i][1]}")
192
+ st.button(f"👉 体验 {features[i][0]}", key=f"btn_{i}")
193
+ if i+1 < len(features):
194
+ with col2:
195
+ with st.expander(features[i+1][0], expanded=True):
196
+ st.markdown(f"**功能描述:** {features[i+1][1]}")
197
+ st.button(f"👉 体验 {features[i+1][0]}", key=f"btn_{i+1}")