{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "84dd8e5e-c3f9-406f-b7d3-8c6f1a412d82", "metadata": { "scrolled": true }, "outputs": [], "source": [
"# Pin the versions this notebook was last run with so a fresh kernel reproduces it.\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# NOTE: datasets 3.5.0 pulls dill 0.3.8 / multiprocess 0.70.16, which pip reports as\n",
"# conflicting with the preinstalled pathos 0.3.3 (needs dill>=0.3.9, multiprocess>=0.70.17).\n",
"%pip install -q 'datasets==3.5.0' 'accelerate==1.6.0' 'transformers==4.51.3' 'sentence-transformers==4.1.0'"
] }, { "cell_type": "code", "execution_count": 2, "id": "966e5a69-2358-42fd-b8f3-484c6be8d0ca", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca88347fcd684ff28a453a43d8c324f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "modules.json: 0%| | 0.00/349 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { 
"model_id": "149a0a1c63f5494db63f2e840ce0d2af", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config_sentence_transformers.json: 0%| | 0.00/124 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "47f0c6dac99843daa984b4f0ea4df563", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/94.6k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "563540a1c8364b058d17a1bfaeed68bb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "sentence_bert_config.json: 0%| | 0.00/52.0 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9dc50352e36244b2ae1ddeb5f67d2129", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/777 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a8b8964d99ec4343b1953433e0b5082d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/438M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8114dd2a109e4f648337c98448dee88a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/366 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "79fa2188f3234329944213e32238bc47", "version_major": 2, "version_minor": 0 }, "text/plain": [ "vocab.txt: 0%| | 0.00/232k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bb94f20f0a704271a3e58fc8536d9469", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"tokenizer.json: 0%| | 0.00/711k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e2ab21c4f2db45b5a2c91df999ae4dca", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/125 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9bb75c8db86541d3a7c9b36e41958e9e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/190 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sentence_transformers import SentenceTransformer, CrossEncoder\n", "embeddings_model_source = 'BAAI/bge-base-en-v1.5'\n", "cross_encodings_model_source = 'cross-encoder/ms-marco-MiniLM-L6-v2'\n", "embeddings_model = SentenceTransformer(embeddings_model_source)\n", "# cross_encodings_model = CrossEncoder(cross_encodings_model_source)" ] }, { "cell_type": "code", "execution_count": 3, "id": "39024f37-a423-4366-a597-b160db5efd18", "metadata": {}, "outputs": [], "source": [
"import io\n",
"import json\n",
"import logging\n",
"import random\n",
"\n",
"import boto3\n",
"from datasets import Dataset\n",
"from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, losses\n",
"from sentence_transformers.evaluation import TripletEvaluator\n",
"from torch.cuda.amp import autocast\n",
"from torch.utils.data import DataLoader\n",
"\n",
"# Route INFO-level logs through sentence-transformers' LoggingHandler so the\n",
"# training loss is printed while embeddings_model.fit(...) runs.\n",
"logging.basicConfig(format='%(asctime)s - %(message)s',\n",
"                    datefmt='%Y-%m-%d %H:%M:%S',\n",
"                    level=logging.INFO,\n",
"                    handlers=[LoggingHandler()])"
] }, { "cell_type": "code", "execution_count": 12, "id": "83e6126d-d853-431d-b30a-55c203e73de1", "metadata": {}, "outputs": [], "source": [
"# === Config ===\n",
"s3_bucket = 'big-cloud-country-sandbox'\n",
"s3_key = 'fine-tune-data/embeddings/embedding_data_hard_negs_4.jsonl'  # or train.jsonl\n",
"local_model = 'BAAI/bge-base-en-v1.5'\n",
"output_path = './finetuned-bge-nuclear-4-3'\n",
"batch_size = 32\n",
"num_epochs = 5\n",
"grad_accumulation_steps = 4\n",
"seed = 42  # shuffle seed for the train/validation split\n",
"split_ratio = 0.8  # fraction of queries used for training"
] }, { "cell_type": "code", "execution_count": 13, "id": "bb38b7e9-38e3-4bb1-9c68-1983387268f4", "metadata": {}, "outputs": [], "source": [
"# === Step 1: Load training data from S3 ===\n",
"s3 = boto3.client('s3', region_name='us-east-2')\n",
"obj = s3.get_object(Bucket=s3_bucket, Key=s3_key)\n",
"# Skip blank lines so json.loads in Step 3 never receives an empty string.\n",
"lines = [line for line in obj['Body'].read().decode('utf-8').splitlines() if line.strip()]\n",
"\n",
"# Shuffle with a fixed seed so the train/validation split is reproducible.\n",
"random.seed(seed)\n",
"random.shuffle(lines)"
] }, { "cell_type": "code", "execution_count": 14, "id": "010a8bec-8ed7-47c3-aefe-72db2c5dcc03", "metadata": {}, "outputs": [], "source": [
"# === Step 2: Split into training and validation sets ===\n",
"# Split by query, not by line: if the file holds several (hard-negative) triplets\n",
"# for the same query, a line-level split puts that query in both sets and\n",
"# inflates validation accuracy. When every query occurs once, this is identical\n",
"# to a plain line-level split of the shuffled lines.\n",
"lines_by_query = {}\n",
"for line in lines:\n",
"    lines_by_query.setdefault(json.loads(line)['query'], []).append(line)\n",
"queries = list(lines_by_query)  # first-seen order, i.e. follows the shuffle\n",
"split_index = int(len(queries) * split_ratio)\n",
"train_lines = [line for query in queries[:split_index] for line in lines_by_query[query]]\n",
"val_lines = [line for query in queries[split_index:] for line in lines_by_query[query]]"
] }, { "cell_type": "code", "execution_count": 15, "id": "06a60caa-fc35-4586-b276-4838eb32fbbc", "metadata": {}, "outputs": [], "source": [
"# === Step 3: Convert to InputExamples ===\n",
"train_examples = []\n",
"for line in train_lines:\n",
"    record = json.loads(line)\n",
"    train_examples.append(InputExample(texts=[\n",
"        record['query'], record['positive'], record['negative']\n",
"    ]))\n",
"\n",
"val_anchors, val_positives, val_negatives = [], [], []\n",
"for line in val_lines:\n",
"    record = json.loads(line)\n",
"    val_anchors.append(record['query'])\n",
"    val_positives.append(record['positive'])\n",
"    val_negatives.append(record['negative'])\n",
"\n",
"assert train_examples and val_anchors, 'train/validation split produced an empty set'"
] }, { "cell_type": 
"code", "execution_count": 16, "id": "6c273e47-6f50-4b53-8339-97a290e1d50f", "metadata": {}, "outputs": [], "source": [
"# === Step 4: Load a fresh base model and set up the train dataloader ===\n",
"# Reload from `local_model` instead of reusing the `embeddings_model` object from an\n",
"# earlier cell: training updates that object's weights in place, so re-running the\n",
"# training cells without this would continue from previously fine-tuned weights.\n",
"embeddings_model = SentenceTransformer(local_model)\n",
"train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)\n",
"\n",
"# === Step 5: Define loss function ===\n",
"# MultipleNegativesRankingLoss also uses the other examples in each batch as negatives.\n",
"# NOTE(review): if the same query/positive can appear twice in one batch it acts as a\n",
"# false negative; consider sentence_transformers.datasets.NoDuplicatesDataLoader then.\n",
"# Alternative: train_loss = losses.TripletLoss(model=embeddings_model)\n",
"train_loss = losses.MultipleNegativesRankingLoss(model=embeddings_model)\n",
"\n",
"# === Step 6: Set up evaluator ===\n",
"evaluator = TripletEvaluator(\n",
"    val_anchors,\n",
"    val_positives,\n",
"    val_negatives,\n",
"    name='validation',\n",
"    batch_size=8,\n",
"    write_csv=True,\n",
")"
] }, { "cell_type": "code", "execution_count": 17, "id": "770bd27a-a93a-4016-a6e3-836930d90688", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9d48de6ea4cf49f88541f22f109c94bd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing widget examples: 0%| | 0/1 [00:00, ?example/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
| Step | \n", "Training Loss | \n", "Validation Loss | \n", "Validation Cosine Accuracy | \n", "
|---|---|---|---|
| 200 | \n", "No log | \n", "No log | \n", "0.927644 | \n", "
| 270 | \n", "No log | \n", "No log | \n", "0.930891 | \n", "
| 400 | \n", "No log | \n", "No log | \n", "0.931354 | \n", "
| 540 | \n", "0.566200 | \n", "No log | \n", "0.935529 | \n", "
| 600 | \n", "0.566200 | \n", "No log | \n", "0.935529 | \n", "
| 800 | \n", "0.566200 | \n", "No log | \n", "0.937384 | \n", "
| 810 | \n", "0.566200 | \n", "No log | \n", "0.935529 | \n", "
| 1000 | \n", "0.242900 | \n", "No log | \n", "0.935993 | \n", "
| 1080 | \n", "0.242900 | \n", "No log | \n", "0.935529 | \n", "
| 1200 | \n", "0.242900 | \n", "No log | \n", "0.935529 | \n", "
| 1350 | \n", "0.242900 | \n", "No log | \n", "0.936920 | \n", "
"
],
"text/plain": [
"