import gradio as gr
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
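
# Dependency note (an assumption, not stated in the original file): MistralClient,
# ChatMessage, and chat_stream come from the pre-1.0 line of the mistralai package,
# so a requirements.txt along these lines would be needed:
#   gradio
#   mistralai<1.0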

def get_stream_chat_completion(
    message, chat_history, model, api_key, system=None, **kwargs
):
    """Stream completion chunks from the Mistral AI chat API."""
    messages = []
    # An optional system prompt goes first.
    if system is not None:
        messages.append(ChatMessage(role="system", content=system))
    # Replay earlier turns as alternating user/assistant messages.
    for human_message, bot_message in chat_history:
        messages.extend(
            (
                ChatMessage(role="user", content=human_message),
                ChatMessage(role="assistant", content=bot_message),
            )
        )
    # Finally, add the new user message.
    messages.append(ChatMessage(role="user", content=message))
    client = MistralClient(api_key=api_key)
    # Yield each token delta as it arrives from the streaming endpoint.
    for chunk in client.chat_stream(
        model=model,
        messages=messages,
        **kwargs,
    ):
        if chunk.choices[0].delta.content is not None:
            yield chunk.choices[0].delta.content
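
# Standalone usage sketch (not part of the original app; assumes a valid key in the
# MISTRAL_API_KEY environment variable and an extra `import os` at the top):
#
#   for text in get_stream_chat_completion(
#       "Hello!", chat_history=[], model="open-mistral-7b",
#       api_key=os.environ["MISTRAL_API_KEY"], temperature=0.2,
#   ):
#       print(text, end="", flush=True)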

def respond_stream(
    message,
    chat_history,
    api_key,
    model,
    temperature,
    top_p,
    max_tokens,
    system,
):
    # Accumulate streamed chunks so the chatbot shows the response growing in place.
    response = ""
    received_anything = False
    for chunk in get_stream_chat_completion(
        message=message,
        chat_history=chat_history,
        model=model,
        api_key=api_key,
        temperature=temperature,
        top_p=top_p,
        max_tokens=int(max_tokens),
        system=system if system else None,
    ):
        response += chunk
        yield response
        received_anything = True
    # If nothing came back at all, the app treats it as an invalid API key.
    if not received_anything:
        gr.Warning("Error: Invalid API Key")
        yield ""

css = """
.header-text p {line-height: 80px !important; text-align: left; font-size: 26px;}
.header-logo {text-align: left;}
.image-container img {max-width: 80px; height: auto;}
"""

with gr.Blocks(title="Mistral Playground", css=css) as mistral_playground:
    # Header row: logo and title.
    with gr.Row():
        with gr.Column(scale=1, min_width=80):
            gr.Image("tt-logo.jpg", show_download_button=False, show_share_button=False, interactive=False, show_label=False, elem_id="thinktecture-logo", container=False)
        with gr.Column(scale=11):
            gr.Markdown("Thinktecture Mistral AI Playground", elem_classes="header-text")
    # API key and model selection.
    with gr.Row(variant='panel'):
        with gr.Column(scale=5):
            api_key = gr.Textbox(type='password', placeholder='Your Mistral AI API key', lines=1, label="Mistral AI API Key")
        with gr.Column(scale=7):
            model = gr.Radio(
                label="Mistral AI Model",
                choices=[
                    ["7B", "open-mistral-7b"],
                    ["8x7B", "open-mixtral-8x7b"],
                    ["Small", "mistral-small-latest"],
                    ["Medium", "mistral-medium-latest"],
                    ["8x22B", "open-mixtral-8x22b"],
                    ["Large", "mistral-large-latest"],
                    ["Codestral", "codestral-latest"],
                ],
                value="mistral-large-latest",
            )
    # Sampling parameters.
    with gr.Row(variant='panel'):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.1)
        top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
        max_tokens = gr.Slider(label="Max Tokens", minimum=1000, maximum=32000, step=1000, value=8000)
    with gr.Row(variant='panel'):
        system = gr.Textbox(lines=2, label="System Message", value="You are a helpful AI assistant")
    # The additional_inputs order must match respond_stream's parameters after
    # (message, chat_history): api_key, model, temperature, top_p, max_tokens, system.
    gr.ChatInterface(
        respond_stream,
        chatbot=gr.Chatbot(render=False, height=500, layout="panel"),
        additional_inputs=[
            api_key,
            model,
            temperature,
            top_p,
            max_tokens,
            system,
        ],
    )
    # Footer with author credit.
    with gr.Row():
        gr.HTML(value="<p style='margin-top: 1rem; margin-bottom: 1rem; text-align: center;'>Developed by Marco Frodl, Principal Consultant for Generative AI @ <a href='https://go.mfr.one/tt-en' target='_blank'>Thinktecture AG</a> -- Released 06/09/2024 -- More about me on my <a href='https://go.mfr.one/marcofrodl-en' target='_blank'>profile page</a></p>")

mistral_playground.launch()
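
# When this script is run locally, launch() serves the UI on http://127.0.0.1:7860 by
# default; on Hugging Face Spaces the platform handles hosting. If a temporary public
# link were wanted (an assumption, not in the original), Gradio supports:
#   mistral_playground.launch(share=True)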