Spaces:
Sleeping
Sleeping
| import dash_mantine_components as dmc | |
| import joblib | |
| import numpy as np | |
| from dash_extensions.enrich import (Dash, DashBlueprint, Input, Output, State, | |
| dcc, exceptions, html) | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.datasets import fetch_20newsgroups | |
| from topicwizard.widgets import (ConceptClusters, DocumentClusters, | |
| TopicBrowser, TopicHierarchy, | |
| create_widget_container) | |
| from turftopic import ClusteringTopicModel, KeyNMF | |
def create_app(blueprint):
    """Build the Dash application that serves ``blueprint``.

    The app is titled "topicwizard" and loads two external scripts:
    the Tailwind CSS CDN build and a Font Awesome kit.

    Parameters
    ----------
    blueprint
        Blueprint object handed to ``Dash`` through its ``blueprint``
        keyword argument.

    Returns
    -------
    The constructed ``Dash`` instance.
    """
    tailwind_script = {"src": "https://cdn.tailwindcss.com"}
    fontawesome_script = {
        "src": "https://kit.fontawesome.com/9640e5cd85.js",
        "crossorigin": "anonymous",
    }
    return Dash(
        __name__,
        blueprint=blueprint,
        title="topicwizard",
        external_scripts=[tailwind_script, fontawesome_script],
    )
# Load the corpus: one document per line of corpus.txt.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's locale default. Blank lines -- including the empty string that a
# trailing newline leaves behind after splitting -- are dropped so that no
# empty documents are passed on to the encoder and the topic models.
with open("corpus.txt", encoding="utf-8") as in_file:
    corpus = [line for line in in_file.read().split("\n") if line.strip()]
# Encode the whole corpus once; the resulting embeddings are handed to both
# topic models further down instead of being recomputed per model.
print("Calculating embeddings")
encoder = SentenceTransformer(
    "sentence-transformers/static-retrieval-mrl-en-v1",
)
embeddings = encoder.encode(
    corpus,
    show_progress_bar=True,
)
# KeyNMF model, constructed with 5 as its first argument and a fixed seed,
# fitted on the precomputed embeddings.
print("Fitting keynmf")
keynmf = KeyNMF(5, random_state=42, encoder=encoder)
keynmf_data = keynmf.prepare_topic_data(
    corpus,
    embeddings=embeddings,
)
# Expand the topic hierarchy of the fitted data by dividing children with 5.
keynmf_data.hierarchy.divide_children(5)
# Clustering topic model (labelled "top2vec" in this script): centroid-based
# feature importance, reduced to 5, fitted on the same precomputed embeddings.
print("Fitting top2vec")
top2vec = ClusteringTopicModel(
    encoder=encoder,
    feature_importance="centroid",
    n_reduce_to=5,
    random_state=0,
)
top2vec_data = top2vec.prepare_topic_data(
    corpus,
    embeddings=embeddings,
)
# One widget container per model. Each gets its own app_id; the KeyNMF side
# shows concept clusters where the top2vec side shows document clusters.
print("Building blueprints.")
keynmf_blueprint = create_widget_container(
    [
        TopicBrowser(),
        ConceptClusters(),
        TopicHierarchy(),
    ],
    keynmf_data,
    app_id="keynmf",
)
top2vec_blueprint = create_widget_container(
    [
        TopicBrowser(),
        DocumentClusters(),
        TopicHierarchy(),
    ],
    top2vec_data,
    app_id="top2vec",
)
# Page layout: a full-size white container holding one column per model.
# Each column is a heading followed by that model's widget container layout;
# the two columns differ only in title, heading colour and blueprint, so they
# are generated from a small table (KeyNMF first, Top2Vec second).
app_blueprint = DashBlueprint()
app_blueprint.layout = html.Div(
    dmc.Group(
        [
            html.Div(
                [
                    dmc.Text(
                        heading,
                        size="xl",
                        fw=700,
                        color=heading_color,
                        align="center",
                        className="pt-8",
                    ),
                    widget_blueprint.layout,
                ],
                className="h-full flex-1 items-center",
            )
            for heading, heading_color, widget_blueprint in (
                ("KeyNMF", "blue.9", keynmf_blueprint),
                ("Top2Vec", "teal.9", top2vec_blueprint),
            )
        ],
        grow=True,
        className="h-full flex-1",
    ),
    # Class string kept verbatim, line breaks included.
    className="""
w-full h-full flex-col flex items-stretch fixed
bg-white
""",
)
# Assemble the Dash application from the combined two-column blueprint.
app = create_app(app_blueprint)
# Module-level handle on the app's server object (presumably so a WSGI-style
# runner can import it -- confirm against the hosting setup).
server = app.server

if __name__ == "__main__":
    # ``run`` is the current Dash entry point; ``run_server`` is its legacy
    # alias, which newer Dash releases no longer provide. Arguments unchanged.
    app.run(debug=False, port=7860)