import streamlit as st
import cohere
import os
import base64

st.set_page_config(page_title="Cohere Chat", layout="wide")

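# Static image assets used for chat avatars and the sidebar banner.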
AI_PFP = "media/pfps/cohere-pfp.png"
USER_PFP = "media/pfps/user-pfp.jpg"
BANNER = "media/banner.png"

if not all(os.path.exists(p) for p in (AI_PFP, USER_PFP, BANNER)):
    st.error("Missing image assets in the media directory")
    st.stop()

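# Metadata for each selectable Cohere model, displayed in the sidebar.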
model_info = {
    "c4ai-aya-expanse-8b": {"description": "Aya Expanse is a highly performant 8B multilingual model, designed to rival monolingual performance through innovations in instruction tuning with data arbitrage, preference training, and model merging. Serves 23 languages.", "context": "4K", "output": "4K"},
    "c4ai-aya-expanse-32b": {"description": "Aya Expanse is a highly performant 32B multilingual model, designed to rival monolingual performance through innovations in instruction tuning with data arbitrage, preference training, and model merging. Serves 23 languages.", "context": "128K", "output": "4K"},
    "c4ai-aya-vision-8b": {"description": "Aya Vision is a state-of-the-art multimodal model excelling at a variety of critical benchmarks for language, text, and image capabilities. This 8 billion parameter variant is focused on low latency and best-in-class performance.", "context": "16K", "output": "4K"},
    "c4ai-aya-vision-32b": {"description": "Aya Vision is a state-of-the-art multimodal model excelling at a variety of critical benchmarks for language, text, and image capabilities. Serves 23 languages. This 32 billion parameter variant is focused on state-of-art multilingual performance.", "context": "16k", "output": "4K"},
    "command-a-03-2025": {"description": "Command A is our most performant model to date, excelling at tool use, agents, retrieval augmented generation (RAG), and multilingual use cases. Command A has a context length of 256K, only requires two GPUs to run, and has 150% higher throughput compared to Command R+ 08-2024.", "context": "256K", "output": "8K"},
    "command-r7b-12-2024": {"description": "command-r7b-12-2024 is a small, fast update delivered in December 2024. It excels at RAG, tool use, agents, and similar tasks requiring complex reasoning and multiple steps.", "context": "128K", "output": "4K"},
    "command-r-plus-04-2024": {"description": "Command R+ is an instruction-following conversational model that performs language tasks at a higher quality, more reliably, and with a longer context than previous models. It is best suited for complex RAG workflows and multi-step tool use.", "context": "128K", "output": "4K"},
}

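# Sidebar: banner, API key input, model picker, and model details.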
with st.sidebar:
    st.image(BANNER, use_container_width=True)
    st.markdown("Hugging Face 🤗 Community UI (Vision Model support coming soon)")
    st.title("Settings")
    api_key = st.text_input("Cohere API Key", type="password")
    selected_model = st.selectbox("Model", options=list(model_info.keys()))
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.session_state.first_message_sent = False
        st.rerun()
    st.divider()
    st.image(AI_PFP, width=60)
    st.subheader(selected_model)
    st.markdown(model_info[selected_model]["description"])
    st.caption(f"Context: {model_info[selected_model]['context']}")
    st.caption(f"Output: {model_info[selected_model]['output']}")
    st.markdown("Powered by Cohere's API")

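# Initialize chat history and the first-message flag in session state.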
if "messages" not in st.session_state:
    st.session_state.messages = []
if "first_message_sent" not in st.session_state:
    st.session_state.first_message_sent = False

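# Main area: greeting, message history, optional image upload, and chat input.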
main = st.container()
with main:
    if not st.session_state.first_message_sent:
        st.markdown(
            "<h1 style='text-align:center; color:#4a4a4a; margin-top:100px;'>How can Cohere help you today?</h1>",
            unsafe_allow_html=True
        )
    for msg in st.session_state.messages:
        avatar = USER_PFP if msg["role"] == "user" else AI_PFP
        with st.chat_message(msg["role"], avatar=avatar):
            st.markdown(msg["content"])

    col1, col2 = st.columns([1, 4])
    with col1:
        if selected_model.startswith("c4ai-aya-vision"):
            uploaded = st.file_uploader("Upload image", type=["png", "jpg", "jpeg"])
        else:
            uploaded = None
    with col2:
        prompt = st.chat_input("Message...")

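# Handle a newly submitted message.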
if prompt:
    if not api_key:
        st.error("API key required")
        st.stop()
    st.session_state.first_message_sent = True
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar=USER_PFP):
        st.markdown(prompt)

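    # Build the request content and call the Cohere v2 Chat API.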
    try:
        co = cohere.ClientV2(api_key)
        user_content = [{"type": "text", "text": prompt}]
        if uploaded:
            # Use the uploaded file's actual MIME type instead of assuming JPEG.
            raw = uploaded.read()
            b64 = base64.b64encode(raw).decode("utf-8")
            mime = uploaded.type or "image/jpeg"
            data_url = f"data:{mime};base64,{b64}"
            user_content.append({"type": "image_url", "image_url": {"url": data_url}})
        # Include prior turns so the model keeps conversational context;
        # the just-appended prompt is excluded and resent with any image block.
        history = [{"role": m["role"], "content": m["content"]}
                   for m in st.session_state.messages[:-1]]
        response = co.chat(
            model=selected_model,
            messages=history + [{"role": "user", "content": user_content}],
        )
        # Concatenate every text block in the response into a single reply.
        content_items = response.message.content or []
        reply = "".join(getattr(item, "text", "") or "" for item in content_items)
        st.session_state.messages.append({"role": "assistant", "content": reply})
        with st.chat_message("assistant", avatar=AI_PFP):
            st.markdown(reply)
    except Exception as e:
        st.error(f"Error: {str(e)}")