{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7eb261b",
   "metadata": {
    "id": "b7eb261b"
   },
   "source": [
    "# NanoChat Easy - SFT Training\n",
    "\n",
    "This notebook loads the `karpathy/nanochat-d32` checkpoint with Hugging Face `transformers` and sanity-checks greedy generation, first on a plain text prompt and then through the tokenizer's chat template.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b8a04a8",
   "metadata": {
    "id": "8b8a04a8"
   },
   "source": [
    "## Load the model and tokenizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e48247c",
   "metadata": {
    "id": "3e48247c",
    "outputId": "882fcf01-34fb-4123-e84c-deefdf477814"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup\n",
    "\n",
    "\n",
    "model_id = \"karpathy/nanochat-d32\"\n",
    "# Hub revision to load; the tokenizer and the model must come from the same one.\n",
    "revision = \"refs/pr/1\"\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    revision=revision,\n",
    "    # `dtype` replaces the deprecated `torch_dtype` argument.\n",
    "    # Use bfloat16 on GPU and fall back to float32 on CPU.\n",
    "    dtype=torch.bfloat16 if device.type == \"cuda\" else torch.float32,\n",
    ").to(device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4810af1a",
   "metadata": {
    "id": "4810af1a"
   },
   "source": [
    "## Demo the model\n",
    "\n",
    "Two quick greedy-decoding checks: a raw text continuation and a single-turn chat built with the tokenizer's chat template.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e81aa9",
   "metadata": {
    "id": "b3e81aa9",
    "outputId": "1cde7e69-7ff1-4bfe-aa9f-9ded20249d82"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "TEST 1: Plain Autoregressive Prompt\n",
      "================================================================================\n",
      "Prompt: The Eiffel Tower stands in Paris and\n",
      "\n",
      "Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "print(\"=\" * 80)\n",
    "print(\"TEST 1: Plain Autoregressive Prompt\")\n",
    "print(\"=\" * 80)\n",
    "prompt = \"The Eiffel Tower stands in Paris and\"\n",
    "test_inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    test_outputs = model.generate(\n",
    "        **test_inputs,\n",
    "        max_new_tokens=64,\n",
    "        do_sample=False,  # greedy decoding (no sampling)\n",
    "        pad_token_id=tokenizer.pad_token_id,\n",
    "    )\n",
    "\n",
    "# Drop the prompt tokens so only the newly generated continuation is decoded.\n",
    "generated_tokens = test_outputs[0, test_inputs[\"input_ids\"].shape[1] :]\n",
    "print(f\"Prompt: {prompt}\")\n",
    "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}\")\n",
    "print(\"=\" * 80)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e7b275c",
   "metadata": {
    "id": "8e7b275c",
    "outputId": "719e986e-61b4-4fd5-db15-4a9ef8f97396"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "TEST 2: Chat Template\n",
      "================================================================================\n",
      "Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>\n",
      "Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]\n",
      "\n",
      "Generated: The capital of France is Paris.<|assistant_end|>\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "print(\"=\" * 80)\n",
    "print(\"TEST 2: Chat Template\")\n",
    "print(\"=\" * 80)\n",
    "conversation = [\n",
    "    {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
    "]\n",
    "\n",
    "# add_generation_prompt=True appends the assistant-start marker so the model replies as the assistant.\n",
    "inputs = tokenizer.apply_chat_template(\n",
    "    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors=\"pt\"\n",
    ").to(device)\n",
    "\n",
    "print(f\"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}\")\n",
    "print(f\"Input IDs: {inputs['input_ids'][0].tolist()}\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(\n",
    "        **inputs,\n",
    "        max_new_tokens=64,\n",
    "        do_sample=False,  # greedy decoding (no sampling)\n",
    "        pad_token_id=tokenizer.pad_token_id,  # same as TEST 1\n",
    "    )\n",
    "\n",
    "# Drop the prompt tokens so only the assistant reply is decoded.\n",
    "generated_tokens = outputs[0, inputs[\"input_ids\"].shape[1] :]\n",
    "# Special tokens are kept on purpose so the <|assistant_end|> terminator is visible.\n",
    "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
    "print(\"=\" * 80)\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}