{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1f0c3e2",
   "metadata": {},
   "source": [
    "# Seq2seq chatbot comparison\n",
    "\n",
    "Loads three pretrained sequence-to-sequence checkpoints from the Hugging Face Hub\n",
    "(`facebook/blenderbot-400M-distill`, `google/flan-t5-base`, `facebook/bart-base`)\n",
    "and exposes a single helper, `chat_with_bots`, that generates a reply from any of them.\n",
    "\n",
    "## Load models and tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7710e8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
    "\n",
    "# Hugging Face Hub checkpoint identifiers.\n",
    "facebook_model_name = \"facebook/blenderbot-400M-distill\"\n",
    "google_model_name = \"google/flan-t5-base\"\n",
    "bart_model_name = \"facebook/bart-base\"\n",
    "\n",
    "# Load the models\n",
    "fb_model = AutoModelForSeq2SeqLM.from_pretrained(facebook_model_name)\n",
    "google_model = AutoModelForSeq2SeqLM.from_pretrained(google_model_name)\n",
    "bart_model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_name)\n",
    "\n",
    "# Tokenizers\n",
    "fb_tokenizer = AutoTokenizer.from_pretrained(facebook_model_name)\n",
    "google_tokenizer = AutoTokenizer.from_pretrained(google_model_name)\n",
    "bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_name)\n",
    "\n",
    "# Registry: `model_name` accepted by chat_with_bots -> (tokenizer, model).\n",
    "BOTS = {\n",
    "    \"google\": (google_tokenizer, google_model),\n",
    "    \"facebook\": (fb_tokenizer, fb_model),\n",
    "    \"bart\": (bart_tokenizer, bart_model),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d2e914",
   "metadata": {},
   "source": [
    "## Chat helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f852aa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_with_bots(input_text, model_name=\"google\", max_new_tokens=1000):\n",
    "    \"\"\"Generate a reply to ``input_text`` with one of the loaded models.\n",
    "\n",
    "    Args:\n",
    "        input_text: Prompt string to send to the model.\n",
    "        model_name: Key into ``BOTS``: \"google\", \"facebook\" or \"bart\".\n",
    "        max_new_tokens: Upper bound on the number of generated tokens.\n",
    "            Capped at the model's ``max_position_embeddings`` when the\n",
    "            model config defines that attribute.\n",
    "\n",
    "    Returns:\n",
    "        The decoded reply with special tokens removed and surrounding\n",
    "        whitespace stripped, or the string \"No model selected\" when\n",
    "        ``model_name`` is not a known key.\n",
    "    \"\"\"\n",
    "    bot = BOTS.get(model_name)\n",
    "    if bot is None:\n",
    "        return \"No model selected\"\n",
    "    tokenizer, model = bot\n",
    "\n",
    "    # Calling the tokenizer (instead of .encode) also returns the attention\n",
    "    # mask, and truncation keeps long prompts within the tokenizer's\n",
    "    # model_max_length.\n",
    "    encoded = tokenizer(input_text, return_tensors=\"pt\", truncation=True)\n",
    "\n",
    "    # NOTE(review): assumes max_position_embeddings bounds how many tokens the\n",
    "    # decoder can produce for these architectures -- confirm per model.\n",
    "    position_limit = getattr(model.config, \"max_position_embeddings\", None)\n",
    "    if position_limit is not None:\n",
    "        max_new_tokens = min(max_new_tokens, position_limit)\n",
    "\n",
    "    outputs = model.generate(**encoded, max_new_tokens=max_new_tokens)\n",
    "    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c93a5f08",
   "metadata": {},
   "source": [
    "## Try it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568e3555",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(chat_with_bots(\"hi, what your name\", model_name=\"facebook\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}