Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
| import pandas as pd | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import re | |
# Page setup: browser-tab title and wide layout.
st.set_page_config(layout="wide", page_title="🛡️ 智盾内容安全审核平台")

# Sidebar label -> internal page identifier compared against
# ``st.session_state.page`` by the routing code further down.
# NOTE(review): "文档效正" looks like a typo for "文档校正" (compare "文本校正"
# one entry above) — confirm with the product owner before changing the label.
PAGES = {
    "🏠 首页": "home",
    "🏢 金融行业审核": "finance",
    "🏛 政府行业审核": "government",
    "🌐 互联网行业审核": "internet",
    "🧠 产品能力": "capability",
    "✍️ 文本校正": "text_correction",
    "📄 文档效正": "doc_alignment",
    "🎙️ 语音检测": "speech_check",
    "💼 加入我们": "join_us",
    "💬 客户反馈": "feedback",
}

# Seed the stored page so the key exists on the very first run.
if "page" not in st.session_state:
    st.session_state["page"] = "home"

# The sidebar radio drives navigation: whatever is selected overwrites the
# stored page on every script run.
selected_page = st.sidebar.radio("📂 页面导航", list(PAGES))
st.session_state["page"] = PAGES[selected_page]
# "<label><optional colon><optional whitespace><score>", where the label is a
# run of CJK ideographs / ASCII letters, the colon may be ASCII ":" or the
# full-width "\uff1a", and the score is either 0.<digits> or 1.0…0.
# Compiled once at import time instead of on every call.
_SCORE_PATTERN = re.compile(
    r"([\u4e00-\u9fa5A-Za-z]+)[:\uff1a]?\s*(0\.\d+|1\.0+)"
)


def parse_scores_from_llm_output(text):
    """Extract ``label -> score`` pairs from free-form model output.

    Args:
        text: Text to scan. ``None`` or an empty string yields an empty dict.

    Returns:
        A dict mapping each matched label to its score as a ``float`` in
        ``[0, 1]``. When a label occurs more than once the last occurrence
        wins. Placeholders such as ``x.xx`` and values above ``1.0`` are
        not matched.
    """
    if not text:
        return {}
    # The pattern only captures well-formed decimals, so float() cannot fail
    # here and no exception handling is required.
    return {
        label: float(score)
        for label, score in _SCORE_PATTERN.findall(text)
    }
def plot_radar_chart(score_dict):
    """Render the scores as a radar (polar) chart in the Streamlit page.

    Args:
        score_dict: Mapping of dimension label to score. The radial axis is
            fixed to the range 0–1.
    """
    labels = list(score_dict)
    scores = list(score_dict.values())
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()

    # Repeat the first point at the end so the outline closes into a polygon.
    closed_scores = scores + scores[:1]
    closed_angles = angles + angles[:1]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw={"polar": True})
    try:
        ax.plot(closed_angles, closed_scores, "o-", linewidth=2)
        ax.fill(closed_angles, closed_scores, alpha=0.25)
        # One spoke label per dimension (the duplicated closing angle is
        # deliberately not labelled).
        ax.set_thetagrids(np.degrees(angles), labels)
        ax.set_ylim(0, 1)
        ax.set_title("📊 风险维度雷达图")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed;
        # without this each rerun of the app would retain another figure.
        plt.close(fig)
def plot_bar_chart(score_dict):
    """Render the scores as a horizontal bar chart in the Streamlit page.

    Args:
        score_dict: Mapping of dimension label to score. The x axis is fixed
            to the range 0–1.
    """
    labels = list(score_dict)
    scores = list(score_dict.values())

    fig, ax = plt.subplots()
    try:
        ax.barh(labels, scores)
        ax.set_xlim(0, 1)
        ax.set_xlabel("分数 (0-1)")
        ax.set_title("📊 风险维度条形图")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed;
        # without this each rerun of the app would retain another figure.
        plt.close(fig)
# ======================== Page rendering =========================
# Name shown in the model selector -> model identifier handed to the
# transformers loaders (from_pretrained / pipeline) by render_audit_page.
model_map = dict(
    BERT="uer/roberta-base-finetuned-jd-binary-chinese",
    GPT="IDEA-CCNL/Taiyi-CLUE-small",
    DeepSeek="deepseek-ai/deepseek-llm-7b-chat",
)

# Sample moderation prompts. Not referenced by any code visible in this
# section of the file.
prompt_list = [
    "这条评论是否包含攻击性言论?请判断并解释。",
    "该文本是否涉及政治敏感内容?",
    "请判断是否应该屏蔽以下评论内容。",
    "请对以下金融评论的合规性进行评分。",
    "请检测是否包含歧视性内容。",
    "请对这条弹幕进行情绪分类。",
    "请指出文本中是否存在隐晦攻击。",
    "这是否是仇恨言论?请给出理由。",
    "请判断该用户评论是否违反平台规则。",
    "请指出内容是否存在偏见或刻板印象。",
]
def _build_audit_prompt(user_text):
    """Return the moderation prompt sent to the generative model."""
    return f"""
你是一个中文内容审核助手,请对以下文本内容进行分析。
任务包括:
1. 判断是否存在违规内容,并简要解释原因。
2. 给出以下五个风险维度的评分(0-1之间,越高代表越严重):
- 人身攻击
- 暴力
- 政治敏感
- 舆论攻击
- 歧视成分
输出格式如下:
内容:<原始文本>
解释:<判断和解释>
维度评分:
- 人身攻击: x.xx
- 暴力: x.xx
- 政治敏感: x.xx
- 舆论攻击: x.xx
- 歧视成分: x.xx
文本如下:
{user_text}
"""


def _run_llm_audit(model_name, user_text):
    """Audit one text with the causal LM and render the answer plus charts."""
    # NOTE(review): the tokenizer and model are reloaded on every button
    # click; consider caching them once the deployed Streamlit version is known.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
    )
    inputs = tokenizer(_build_audit_prompt(user_text), return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)

    # For a causal LM the generated sequence starts with the prompt tokens.
    # Drop them so the prompt is not echoed as the "result" and scores are
    # never parsed out of the user's own text.
    prompt_length = inputs["input_ids"].shape[1]
    explanation = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    st.markdown("### 📋 审核结果")
    st.write(explanation)
    scores = parse_scores_from_llm_output(explanation)
    # Charts are only drawn when at least three dimensions were extracted.
    if len(scores) >= 3:
        plot_radar_chart(scores)
        plot_bar_chart(scores)
    else:
        st.info("未提取出结构化评分维度")


def _run_classifier_audit(model_name, user_text, uploaded_file):
    """Classify a single text, or every row of an uploaded CSV, and render it."""
    classifier = pipeline(
        "text-classification",
        model=model_name,
        tokenizer=model_name,
        device=0 if torch.cuda.is_available() else -1,
    )
    if user_text:
        result = classifier(user_text)[0]
        st.markdown("### 📋 审核结果")
        st.write(f"标签: {result['label']} / 置信度: {result['score']:.2f}")
        return

    df = pd.read_csv(uploaded_file)
    if "text" not in df.columns:
        st.error("CSV 文件需包含 'text' 列")
        return
    # Empty cells are read as NaN; turn them into "" so the pipeline always
    # receives strings, and classify all rows in one call.
    texts = df["text"].fillna("").astype(str).tolist()
    df["预测标签"] = [pred["label"] for pred in classifier(texts)] if texts else []
    st.dataframe(df)


def render_audit_page(title, task_prompt):
    """Render an audit page: model picker, input widgets and audit results.

    Args:
        title: Page heading.
        task_prompt: Industry-specific instruction supplied by the caller.
            It is not used by the page at present; the parameter is kept so
            existing callers keep working.
    """
    st.title(title)
    model_choice = st.selectbox("🧠 选择模型", list(model_map.keys()))
    model_name = model_map[model_choice]
    input_type = st.radio("输入方式", ["单条输入", "上传CSV文件"])

    # Bind both inputs up front so neither name is undefined when the other
    # input mode is active.
    user_text = None
    uploaded_file = None
    if input_type == "单条输入":
        user_text = st.text_area("请输入文本:", height=150)
    else:
        uploaded_file = st.file_uploader("上传CSV文件(包含'text'列)", type=["csv"])

    if not st.button("🚨 开始审核"):
        return

    # Validate before loading any model so an empty request stays cheap.
    if not user_text and uploaded_file is None:
        st.warning("请先输入文本或上传CSV文件")
        return
    if model_choice == "DeepSeek" and not user_text:
        st.warning("DeepSeek 模型暂不支持 CSV 批量审核,请使用单条输入")
        return

    with st.spinner("审核中..."):
        if model_choice == "DeepSeek":
            _run_llm_audit(model_name, user_text)
        else:
            _run_classifier_audit(model_name, user_text, uploaded_file)
# Page routing: render the page selected in the sidebar.
# The "capability" branch has to come before the final ``else``; an ``elif``
# placed after ``else`` is a syntax error and could never be reached.
if st.session_state.page == "home":
    st.title("🛡️ 智盾内容安全审核平台")
    st.markdown("欢迎使用智盾平台,本系统为政府、金融、互联网行业提供智能内容安全审核服务。\n\n请选择左侧行业进入审核流程。")
elif st.session_state.page == "finance":
    render_audit_page("🏢 金融行业内容审核", "请审核金融评论内容是否存在合规风险")
elif st.session_state.page == "government":
    render_audit_page("🏛 政府行业内容审核", "请判断该内容是否存在政治敏感或违规用语")
elif st.session_state.page == "internet":
    render_audit_page("🌐 互联网内容审核(文本/语音/弹幕)", "请分析该用户生成内容是否违规")
elif st.session_state.page == "capability":
    st.title("🧠 产品能力")
    st.markdown("本平台具备全面的文本纠错、审核、比对与生成能力,适用于政务、金融、媒体等场景。")
    # (card title, description) pairs, laid out two per row below.
    features = [
        ("📝 字词错误", "错别字、音近字、形近字、多字、重叠、颠倒、异形词等"),
        ("📌 常识错误", "标点符号、地名关联、表达不当、语义错误、不语名词等"),
        ("🚫 敏感词过滤", "涉及暴恐、色情、违禁、侮辱、歧视等不健康词语"),
        ("⚠️ 政治性差错", "领导人姓名、职务、讲话、政治口号、固定表述等"),
        ("📄 文本比对", "快速找出两个文本之间的差异之处,高清高亮显示"),
        ("📐 格式错误", "参照国家标准和党政公文规范,自动识别格式问题"),
        ("🤖 智能写作", "自动生成新闻稿、公告、任务文书,响应快速"),
        ("🌐 网站巡检", "自动抓取网页历史快照,输出违规风险报告"),
    ]
    for row_start in range(0, len(features), 2):
        # A fresh pair of columns per row; the right column stays empty when
        # the number of features is odd.
        for offset, column in enumerate(st.columns(2)):
            index = row_start + offset
            if index >= len(features):
                break
            name, description = features[index]
            with column:
                with st.expander(name, expanded=True):
                    st.markdown(f"**功能描述:** {description}")
                    # The key is derived from the feature index so every
                    # button has a distinct widget key.
                    st.button(f"👉 体验 {name}", key=f"btn_{index}")
else:
    # Every remaining page identifier is a placeholder for now.
    st.title(f"🧩 {selected_page}")
    st.info("🚧 此模块为占位页面,后续即将上线。")