Update app.py
app.py CHANGED
@@ -1,149 +1,197 @@
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
- import torch
  import streamlit as st
- from PIL import Image
- import pytesseract
- import openai
  import pandas as pd
- import plotly.express as px
- "
- "
- "
  }
- st.
- uploaded_file = st.file_uploader("Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
-
- if uploaded_file is not None:
-     image = Image.open(uploaded_file)
-     st.image(image, caption="Uploaded Screenshot", use_column_width=True)
-
-     with st.spinner("🧠 Extracting text via OCR..."):
-         ocr_text = pytesseract.image_to_string(image, lang="chi_sim+eng")
-     st.markdown("#### 📋 Extracted Text:")
-     st.code(ocr_text.strip())
-
-     translated, label, score, reason = classify_emoji_text(ocr_text.strip())
-     st.markdown("### 🔄 Translated sentence:")
-     st.code(translated, language="text")
-
-     st.markdown(f"### 🎯 Prediction: `{label}`")
-     st.markdown(f"### 📊 Confidence Score: `{score:.2%}`")
-     st.markdown("### 🧠 Model Explanation:")
-     st.info(reason)
-
- elif section == "📊 Text Analysis":
-     st.title("📊 Violation Analysis Dashboard")
-     if st.session_state.history:
-         df = pd.DataFrame(st.session_state.history)
-         label_counts = df["label"].value_counts().reset_index()
-         label_counts.columns = ["Category", "Count"]
-         fig = px.pie(label_counts, names="Category", values="Count", title="Offensive Category Distribution", color_discrete_sequence=px.colors.sequential.RdBu)
-         st.plotly_chart(fig)
-
-         st.markdown("### 🧾 Offensive Terms & Suggestions")
-         for item in st.session_state.history:
-             st.markdown(f"- 🔹 **Input:** `{item['text']}`")
-             st.markdown(f"  - ✨ **Translated:** `{item['translated']}`")
-             st.markdown(f"  - ❗ **Label:** `{item['label']}` with **{item['score']:.2%}** confidence")
-             st.markdown(f"  - 🔧 **Suggestion:** {item['reason']}")
-
-         radar_df = pd.DataFrame({
-             "Category": ["Insult", "Abuse", "Discrimination", "Hate Speech", "Vulgarity"],
-             "Score": [0.7, 0.4, 0.3, 0.5, 0.6]
-         })
-         radar_fig = px.line_polar(radar_df, r='Score', theta='Category', line_close=True, title="⚠️ Risk Radar by Category")
-         st.plotly_chart(radar_fig)
  else:
-     st.
  import streamlit as st
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
  import pandas as pd
+ import torch
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import re
+
+ # Page configuration
+ st.set_page_config(page_title="🛡️ 智盾内容安全审核平台", layout="wide")
+ PAGES = {
+     "🏠 首页": "home",
+     "🏢 金融行业审核": "finance",
+     "🏛 政府行业审核": "government",
+     "🌐 互联网行业审核": "internet",
+     "🧠 产品能力": "capability",
+     "✍️ 文本校正": "text_correction",
+     "📄 文档效正": "doc_alignment",
+     "🎙️ 语音检测": "speech_check",
+     "💼 加入我们": "join_us",
+     "💬 客户反馈": "feedback"
  }
+ if "page" not in st.session_state:
+     st.session_state.page = "home"
+ selected_page = st.sidebar.radio("📂 页面导航", list(PAGES.keys()))
+ st.session_state.page = PAGES[selected_page]
+
+ # Pull "label: score" pairs out of free-form LLM output: a run of CJK or Latin
+ # letters, an optional full- or half-width colon, then a decimal in [0, 1].
+ def parse_scores_from_llm_output(text):
+     matches = re.findall(r"([\u4e00-\u9fa5A-Za-z]+)[::]?\s*([0]\.\d+|1\.0+)", text)
+     score_dict = {}
+     for label, score in matches:
+         try:
+             score_dict[label.strip()] = float(score)
+         except ValueError:
+             continue
+     return score_dict
+
+ def plot_radar_chart(score_dict):
+     labels = list(score_dict.keys())
+     scores = list(score_dict.values())
+     angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
+     # Close the polygon by repeating the first point.
+     scores += scores[:1]
+     angles += angles[:1]
+     fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
+     ax.plot(angles, scores, "o-", linewidth=2)
+     ax.fill(angles, scores, alpha=0.25)
+     ax.set_thetagrids(np.degrees(angles[:-1]), labels)
+     ax.set_ylim(0, 1)
+     ax.set_title("📊 风险维度雷达图")
+     st.pyplot(fig)
+
+ def plot_bar_chart(score_dict):
+     labels = list(score_dict.keys())
+     scores = list(score_dict.values())
+     fig, ax = plt.subplots()
+     ax.barh(labels, scores)
+     ax.set_xlim(0, 1)
+     ax.set_xlabel("分数 (0-1)")
+     ax.set_title("📊 风险维度条形图")
+     st.pyplot(fig)
+
+ # ======================== Page rendering =========================
+ model_map = {
+     "BERT": "uer/roberta-base-finetuned-jd-binary-chinese",
+     "GPT": "IDEA-CCNL/Taiyi-CLUE-small",
+     "DeepSeek": "deepseek-ai/deepseek-llm-7b-chat"
+ }
+ # Ten candidate moderation prompts (offensive language, political sensitivity,
+ # blocking decisions, compliance scoring, discrimination, sentiment, veiled
+ # attacks, hate speech, platform-rule violations, bias/stereotypes).
+ prompt_list = [
+     "这条评论是否包含攻击性言论?请判断并解释。",
+     "该文本是否涉及政治敏感内容?",
+     "请判断是否应该屏蔽以下评论内容。",
+     "请对以下金融评论的合规性进行评分。",
+     "请检测是否包含歧视性内容。",
+     "请对这条弹幕进行情绪分类。",
+     "请指出文本中是否存在隐晦攻击。",
+     "这是否是仇恨言论?请给出理由。",
+     "请判断该用户评论是否违反平台规则。",
+     "请指出内容是否存在偏见或刻板印象。"
+ ]
+
+ def render_audit_page(title, task_prompt):  # task_prompt is reserved; not used yet
+     st.title(title)
+     model_choice = st.selectbox("🧠 选择模型", list(model_map.keys()))
+     model_name = model_map[model_choice]
+     input_type = st.radio("输入方式", ["单条输入", "上传CSV文件"])
+     user_text = None
+     uploaded_file = None  # initialize so the batch branch below cannot raise NameError
+     if input_type == "单条输入":
+         user_text = st.text_area("请输入文本:", height=150)
      else:
+         uploaded_file = st.file_uploader("上传CSV文件(包含'text'列)", type=["csv"])
+
+     if st.button("🚨 开始审核"):
+         with st.spinner("审核中..."):
+             if model_choice == "DeepSeek":
+                 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+                 model = AutoModelForCausalLM.from_pretrained(
+                     model_name,
+                     device_map="auto",
+                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                     trust_remote_code=True
+                 )
+                 if user_text:
+                     # Chinese moderation prompt: asks for a violation judgment plus
+                     # 0-1 scores on five risk dimensions (personal attack, violence,
+                     # political sensitivity, opinion attack, discrimination) in a
+                     # fixed "label: score" format that parse_scores_from_llm_output
+                     # can pick up.
+                     prompt = f"""
+ 你是一个中文内容审核助手,请对以下文本内容进行分析。
+
+ 任务包括:
+ 1. 判断是否存在违规内容,并简要解释原因。
+ 2. 给出以下五个风险维度的评分(0-1之间,越高代表越严重):
+ - 人身攻击
+ - 暴力
+ - 政治敏感
+ - 舆论攻击
+ - 歧视成分
+
+ 输出格式如下:
+ 内容:<原始文本>
+ 解释:<判断和解释>
+ 维度评分:
+ - 人身攻击: x.xx
+ - 暴力: x.xx
+ - 政治敏感: x.xx
+ - 舆论攻击: x.xx
+ - 歧视成分: x.xx
+
+ 文本如下:
+ {user_text}
+ """
+                     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+                     outputs = model.generate(**inputs, max_new_tokens=512)
+                     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+                     explanation = result.split("\n", 1)[-1]
+                     st.markdown("### 📋 审核结果")
+                     st.write(explanation)
+                     scores = parse_scores_from_llm_output(explanation)
+                     if len(scores) >= 3:
+                         plot_radar_chart(scores)
+                         plot_bar_chart(scores)
+                     else:
+                         st.info("未提取出结构化评分维度")
+             else:
+                 classifier = pipeline("text-classification", model=model_name, tokenizer=model_name, device=0 if torch.cuda.is_available() else -1)
+                 if user_text:
+                     result = classifier(user_text)[0]
+                     st.markdown("### 📋 审核结果")
+                     st.write(f"标签: {result['label']} / 置信度: {result['score']:.2f}")
+                 elif uploaded_file:
+                     df = pd.read_csv(uploaded_file)
+                     if 'text' not in df.columns:
+                         st.error("CSV 文件需包含 'text' 列")
+                     else:
+                         df["预测标签"] = df["text"].apply(lambda x: classifier(x)[0]['label'])
+                         st.dataframe(df)
+
+ # Page-dispatch logic
+ if st.session_state.page == "home":
+     st.title("🛡️ 智盾内容安全审核平台")
+     st.markdown("欢迎使用智盾平台,本系统为政府、金融、互联网行业提供智能内容安全审核服务。\n\n请选择左侧行业进入审核流程。")
+
+ elif st.session_state.page == "finance":
+     render_audit_page("🏢 金融行业内容审核", "请审核金融评论内容是否存在合规风险")
+
+ elif st.session_state.page == "government":
+     render_audit_page("🏛 政府行业内容审核", "请判断该内容是否存在政治敏感或违规用语")
+
+ elif st.session_state.page == "internet":
+     render_audit_page("🌐 互联网内容审核(文本/语音/弹幕)", "请分析该用户生成内容是否违规")
+
+ elif st.session_state.page == "capability":
+     st.title("🧠 产品能力")
+     st.markdown("本平台具备全面的文本纠错、审核、比对与生成能力,适用于政务、金融、媒体等场景。")
+
+     features = [
+         ("📝 字词错误", "错别字、音近字、形近字、多字、重叠、颠倒、异形词等"),
+         ("📌 常识错误", "标点符号、地名关联、表达不当、语义错误、不语名词等"),
+         ("🚫 敏感词过滤", "涉及暴恐、色情、违禁、侮辱、歧视等不健康词语"),
+         ("⚠️ 政治性差错", "领导人姓名、职务、讲话、政治口号、固定表述等"),
+         ("📄 文本比对", "快速找出两个文本之间的差异之处,高清高亮显示"),
+         ("📐 格式错误", "参照国家标准和党政公文规范,自动识别格式问题"),
+         ("🤖 智能写作", "自动生成新闻稿、公告、任务文书,响应快速"),
+         ("🌐 网站巡检", "自动抓取网页历史快照,输出违规风险报告")
+     ]
+
+     # Lay the feature cards out two per row.
+     for i in range(0, len(features), 2):
+         col1, col2 = st.columns(2)
+         with col1:
+             with st.expander(features[i][0], expanded=True):
+                 st.markdown(f"**功能描述:** {features[i][1]}")
+                 st.button(f"👉 体验 {features[i][0]}", key=f"btn_{i}")
+         if i + 1 < len(features):
+             with col2:
+                 with st.expander(features[i+1][0], expanded=True):
+                     st.markdown(f"**功能描述:** {features[i+1][1]}")
+                     st.button(f"👉 体验 {features[i+1][0]}", key=f"btn_{i+1}")
+
+ # Remaining placeholder pages (text correction, doc alignment, speech check,
+ # join us, feedback); this catch-all branch must come after all elif branches.
+ else:
+     st.title(f"🧩 {selected_page}")
+     st.info("🚧 此模块为占位页面,后续即将上线。")
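For reference, a minimal sketch of how the new score-parsing and charting helpers work together, using a hypothetical model reply that follows the prompt's "label: score" output format (the sample text and scores are illustrative, not actual DeepSeek output):

# Hypothetical reply in the format the prompt requests.
sample = """内容:示例文本
解释:包含轻微的人身攻击。
维度评分:
- 人身攻击: 0.72
- 暴力: 0.10
- 政治敏感: 0.05
- 舆论攻击: 0.33
- 歧视成分: 0.08"""

scores = parse_scores_from_llm_output(sample)
# {'人身攻击': 0.72, '暴力': 0.1, '政治敏感': 0.05, '舆论攻击': 0.33, '歧视成分': 0.08}
# (labels: personal attack, violence, political sensitivity, opinion attack, discrimination)

# Five dimensions were extracted (>= 3), so the app would draw both charts:
plot_radar_chart(scores)
plot_bar_chart(scores)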
|