# app.py — Gradio chat Space for nvidia/NVIDIA-Nemotron-Nano-9B-v2
# (Hugging Face Space header: "Running on Zero", i.e. ZeroGPU hardware)
import gradio as gr
import torch
import spaces
import subprocess
import sys

# Pin the transformers version recommended by the model card. On Spaces this
# is normally done via requirements.txt; the runtime install below also works
# but slows startup.
subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.48.3"])

from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer eagerly; the model itself is loaded lazily (see below)
model_name = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None

def load_model():
    """Load the model on first use and cache it in the module-level global."""
    global model
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )
    return model
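# Note (assumption based on the "Running on Zero" header): on ZeroGPU the GPU
# is only attached while a @spaces.GPU-decorated function runs, so deferring
# the weight load until the first request keeps startup light and lets
# device_map="auto" see the GPU once it is available.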
# ZeroGPU Spaces require GPU work to happen inside a @spaces.GPU-decorated
# function; without this decorator, generation would fall back to CPU.
@spaces.GPU
def generate_response(message, history, enable_reasoning, temperature, top_p, max_tokens):
    """Generate a response from the model."""
    # Prepare messages with reasoning control: the model toggles chain-of-thought
    # via a /think or /no_think system message.
    messages = []
    if enable_reasoning:
        messages.append({"role": "system", "content": "/think"})
    else:
        messages.append({"role": "system", "content": "/no_think"})

    # Add conversation history (list of [user, assistant] pairs)
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})
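    # For reference, a single-turn request with reasoning enabled produces:
    # [{"role": "system", "content": "/think"},
    #  {"role": "user", "content": "Explain quantum computing in simple terms"}]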
    # Load the model if needed
    model = load_model()

    # Tokenize the conversation using the model's chat template
    tokenized_chat = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    # Set generation parameters based on reasoning mode
    if enable_reasoning:
        # Recommended sampling settings for reasoning
        generation_kwargs = {
            "temperature": temperature if temperature > 0 else 0.6,
            "top_p": top_p if top_p < 1 else 0.95,
            "do_sample": True,
            "max_new_tokens": max_tokens,
            "eos_token_id": tokenizer.eos_token_id,
        }
    else:
        # Greedy decoding for non-reasoning
        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": max_tokens,
            "eos_token_id": tokenizer.eos_token_id,
        }

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(tokenized_chat, **generation_kwargs)

    # Decode only the newly generated tokens (everything after the prompt)
    generated_tokens = outputs[0][tokenized_chat.shape[-1]:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response
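# Hypothetical local smoke test (not part of the Space; requires enough GPU
# memory for the bfloat16 weights):
#   reply = generate_response("What is 2 + 2?", [], enable_reasoning=False,
#                             temperature=0.6, top_p=0.95, max_tokens=64)
#   print(reply)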
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # NVIDIA Nemotron Nano 9B v2 Chatbot

        This chatbot uses the NVIDIA Nemotron Nano 9B v2 model with optional reasoning capabilities.

        - **Enable Reasoning**: Activates the model's chain-of-thought reasoning (/think mode)
        - **Disable Reasoning**: Uses direct response generation (/no_think mode)

        **Note:** Using transformers version 4.48.3, as recommended by the model documentation.
        """
    )
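    # The /think and /no_think switches described above are the reasoning-control
    # system prompts that generate_response() injects, matching the
    # "Enable Reasoning" checkbox defined below.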
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here...",
        lines=2,
    )

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")
    with gr.Accordion("Advanced Settings", open=False):
        enable_reasoning = gr.Checkbox(
            label="Enable Reasoning (/think mode)",
            value=True,
            info="Enable chain-of-thought reasoning for complex queries",
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=0.6,
            step=0.1,
            label="Temperature",
            info="Controls randomness (recommended: 0.6 for reasoning; ignored in non-reasoning mode)",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Controls diversity (recommended: 0.95 for reasoning; ignored in non-reasoning mode)",
        )
        max_tokens = gr.Slider(
            minimum=32,
            maximum=2048,
            value=1024,
            step=32,
            label="Max New Tokens",
            info="Maximum number of tokens to generate (recommended: 1024+ for reasoning)",
        )
    def user_submit(message, history):
        """Append the user's message to the chat and clear the input box."""
        return "", history + [[message, None]]

    def bot_response(history, enable_reasoning, temperature, top_p, max_tokens):
        """Fill in the assistant's reply for the last user message."""
        if not history:
            return history
        message = history[-1][0]
        try:
            response = generate_response(
                message,
                history[:-1],
                enable_reasoning,
                temperature,
                top_p,
                max_tokens,
            )
            history[-1][1] = response
        except Exception as e:
            history[-1][1] = f"Error generating response: {str(e)}"
        return history
    msg.submit(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot_response,
        [chatbot, enable_reasoning, temperature, top_p, max_tokens],
        chatbot,
    )
    submit.click(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot_response,
        [chatbot, enable_reasoning, temperature, top_p, max_tokens],
        chatbot,
    )
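    # The two-step wiring above (user_submit, then bot_response) makes the
    # user's message appear immediately while the model's reply is computed.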
    clear.click(lambda: None, None, chatbot, queue=False)
    # Example prompts
    gr.Examples(
        examples=[
            "Write a haiku about GPUs",
            "Explain quantum computing in simple terms",
            "What is the capital of France?",
            "Solve this step by step: If a train travels 120 miles in 2 hours, what is its average speed?",
            "Write a short story about a robot learning to paint",
        ],
        inputs=msg,
    )
if __name__ == "__main__":
    demo.launch()
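# Hypothetical local run (outside Spaces): install gradio, spaces, torch and
# transformers==4.48.3, then run `python app.py`. The 9B bfloat16 weights need
# roughly 18 GB of GPU memory.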