topic-arena-demo / main.py
kardosdrur's picture
Added corpus
08620e1
raw
history blame
3.24 kB
import dash_mantine_components as dmc
import joblib
import numpy as np
from dash_extensions.enrich import (Dash, DashBlueprint, Input, Output, State,
dcc, exceptions, html)
from sentence_transformers import SentenceTransformer
from sklearn.datasets import fetch_20newsgroups
from topicwizard.widgets import (ConceptClusters, DocumentClusters,
TopicBrowser, TopicHierarchy,
create_widget_container)
from turftopic import ClusteringTopicModel, KeyNMF
def create_app(blueprint):
    """Build the Dash application that serves ``blueprint``.

    The app is titled "topicwizard" and is configured to load two external
    scripts in the browser: the Tailwind CSS CDN build and a Font Awesome kit.

    Parameters
    ----------
    blueprint
        Blueprint handed to ``Dash`` via its ``blueprint`` argument.

    Returns
    -------
    The constructed ``Dash`` instance.
    """
    tailwind_script = {"src": "https://cdn.tailwindcss.com"}
    fontawesome_script = {
        "src": "https://kit.fontawesome.com/9640e5cd85.js",
        "crossorigin": "anonymous",
    }
    return Dash(
        __name__,
        blueprint=blueprint,
        title="topicwizard",
        external_scripts=[tailwind_script, fontawesome_script],
    )
# Load the corpus: every non-blank line of ``corpus.txt`` is one document.
# The encoding is pinned so decoding does not depend on the host locale, and
# blank lines (including the empty string that ``split("\n")`` yields after a
# trailing newline) are dropped so they are not embedded and modelled as
# empty documents.
with open("corpus.txt", encoding="utf-8") as in_file:
    corpus = [line for line in in_file.read().split("\n") if line.strip()]
# Encode the corpus once; the same embeddings are passed to both models below
# so neither has to re-encode the documents.
print("Calculating embeddings")
encoder = SentenceTransformer(
    "sentence-transformers/static-retrieval-mrl-en-v1",
)
embeddings = encoder.encode(corpus, show_progress_bar=True)

# KeyNMF model; 5 is passed positionally (presumably the number of topics --
# confirm against the turftopic API).
print("Fitting keynmf")
keynmf = KeyNMF(5, random_state=42, encoder=encoder)
keynmf_data = keynmf.prepare_topic_data(corpus, embeddings=embeddings)
# Called for its effect on the hierarchy; the return value is not used.
# NOTE(review): presumably splits each topic into 5 subtopics -- confirm.
keynmf_data.hierarchy.divide_children(5)

# Clustering model configured with centroid-based feature importance and
# ``n_reduce_to=5``; the variable name suggests a Top2Vec-style setup.
print("Fitting top2vec")
top2vec = ClusteringTopicModel(
    encoder=encoder,
    feature_importance="centroid",
    n_reduce_to=5,
    random_state=0,
)
top2vec_data = top2vec.prepare_topic_data(corpus, embeddings=embeddings)
print("Building blueprints.")

# One widget container per model, each with its own ``app_id``.
# NOTE(review): the distinct app_id values presumably keep the component ids
# of the two containers from colliding on the shared page -- confirm.
_keynmf_widgets = [TopicBrowser(), ConceptClusters(), TopicHierarchy()]
keynmf_blueprint = create_widget_container(
    _keynmf_widgets,
    keynmf_data,
    app_id="keynmf",
)

_top2vec_widgets = [TopicBrowser(), DocumentClusters(), TopicHierarchy()]
top2vec_blueprint = create_widget_container(
    _top2vec_widgets,
    top2vec_data,
    app_id="top2vec",
)
def _model_column(title, color, blueprint):
    """Return one page column: a centred title above ``blueprint``'s layout."""
    heading = dmc.Text(
        title,
        size="xl",
        fw=700,
        color=color,
        align="center",
        className="pt-8",
    )
    return html.Div(
        [heading, blueprint.layout],
        className="h-full flex-1 items-center",
    )


# Page layout: the KeyNMF and Top2Vec widget containers side by side, each
# under its own coloured heading, inside a full-size white wrapper.
app_blueprint = DashBlueprint()
app_blueprint.layout = html.Div(
    dmc.Group(
        [
            _model_column("KeyNMF", "blue.9", keynmf_blueprint),
            _model_column("Top2Vec", "teal.9", top2vec_blueprint),
        ],
        grow=True,
        className="h-full flex-1",
    ),
    # Same runtime value as the multi-line class string, newlines included.
    className=(
        "\nw-full h-full flex-col flex items-stretch fixed\n"
        "bg-white\n"
    ),
)
# Build the Dash app around the combined layout blueprint.
app = create_app(app_blueprint)
# Module-level handle on the app's underlying server object.
# NOTE(review): presumably what a WSGI host imports -- confirm deployment.
server = app.server

if __name__ == "__main__":
    # Direct execution: serve on port 7860 with debug mode off.
    app.run_server(port=7860, debug=False)