mirror of
https://github.com/0xSojalSec/airllm.git
synced 2026-04-28 09:30:07 +00:00
315 lines
9.9 KiB
Plaintext
315 lines
9.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6d8683b7-0fab-4937-b7ad-72d70a0260ac",
|
|
"metadata": {},
|
|
"source": [
|
|
"# tokenizer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "d33c6c4e-9fd7-4850-8800-12ac35a867a0",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Requirement already satisfied: wonderwords in /usr/local/anaconda3/envs/ghostaienv/lib/python3.8/site-packages (2.2.0)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!pip install wonderwords"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "591305e5-0459-4f0b-9968-f77d207d0172",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os, json\n",
|
|
"from tqdm import tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "3cca5b71-75c5-44bc-9e80-330c93915f3d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
from transformers import LlamaTokenizer
import torch

# In this notebook the tokenizer is only used to count the tokens of each
# generated prompt; no model weights are loaded here.
base_model = "huggyllama/llama-13b"
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "041ff3ce-d593-4f7d-be0e-c5488aeb9156",
|
|
"metadata": {},
|
|
"source": [
|
|
"# gen topic eval dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "97a48cc7-7c41-4e7d-96ca-4771472a3e81",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
import numpy as np
import random

# Seed every RNG the notebook draws from. The dataset cells sample with the
# stdlib `random` module (random.randint), so seeding NumPy alone does not make
# the generated records reproducible.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1409125a-2030-4067-9a33-8612c4cd668b",
|
|
"metadata": {},
|
|
"source": [
|
|
"# gen lines eval dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "4bf7c52c-ce66-483e-9a6a-c6067d1dbdeb",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
def generate_line_index(num_line, idx_opt):
    """Build ``num_line`` line identifiers in the style selected by ``idx_opt``.

    Args:
        num_line: Number of identifiers to generate.
        idx_opt: Identifier style, one of:
            ``"LRT-ABCindex"`` - strings over the letters A-F, in
            ``itertools.product`` order, at least six characters long;
            ``"LRT-UUID"`` - random UUID4 strings;
            ``"LRT-NL"`` - ``"<noun>-<noun>"`` pairs of random words from
            the ``wonderwords`` package.

    Returns:
        list[str]: The generated identifiers.

    Raises:
        ValueError: If ``idx_opt`` is not one of the styles listed above.
    """
    if idx_opt == "LRT-ABCindex":
        # Local import: none of the notebook's import cells provide itertools.
        import itertools

        ingredients = ["A", "B", "C", "D", "E", "F"]

        # Smallest word length (never below 6) that yields at least
        # num_line distinct combinations.
        width = 6
        while len(ingredients) ** width < num_line:
            width += 1

        # Take only the needed prefix instead of materialising the whole product.
        combos = itertools.product(ingredients, repeat=width)
        return ["".join(combo) for combo in itertools.islice(combos, num_line)]

    if idx_opt == "LRT-UUID":
        # Local import: none of the notebook's import cells provide uuid.
        import uuid

        return [str(uuid.uuid4()) for _ in range(num_line)]

    if idx_opt == "LRT-NL":
        import wonderwords

        w = wonderwords.RandomWord()
        # NOTE(review): the first half used to be called `adjs` but has always
        # been requested from the "noun" category. Kept as-is so generated data
        # does not change; presumably the adjective list is too small for
        # thousands of draws - confirm before switching categories.
        first_words = w.random_words(num_line, include_categories=["noun"])
        second_words = w.random_words(num_line, include_categories=["noun"])

        return [f"{first}-{second}" for first, second in zip(first_words, second_words)]

    # Previously an unknown option fell through and returned None, which only
    # failed later (and less clearly) in the caller.
    raise ValueError(f"Unknown line index option: {idx_opt!r}")
|
|
" \n",
|
def retrieve_expected(lines, random_line_pos):
    """Pick one record line and parse the number stored in its ``<...>`` field.

    Args:
        lines: Sequence of record line strings.
        random_line_pos: Index into ``lines`` of the line to retrieve.

    Returns:
        tuple: ``(expected_number, correct_line)``. ``expected_number`` is the
        integer found between angle brackets in the selected line, or ``None``
        (after printing a notice) when the line has no ``<digits>`` field.
    """
    correct_line = lines[random_line_pos]
    # Raw string: "\d" inside a plain string literal is an invalid escape
    # sequence and triggers a warning on current Python versions.
    match = re.search(r"<(\d+)>", correct_line)
    if match is None:
        print(f"Got unparsable line: {correct_line}")
        return None, correct_line

    return int(match.group(1)), correct_line
|
|
"\n",
|
def generate_prompt_from_lines(lines):
    """Concatenate the given strings, in order, into a single prompt string.

    Args:
        lines: Iterable of strings; each is appended verbatim (no separator
            is inserted between them).

    Returns:
        str: The concatenation of all entries, or ``""`` when ``lines`` is empty.
    """
    # str.join builds the result in one pass instead of repeated `+=`.
    return "".join(lines)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "bbfe0587-b2ad-4334-88a7-5f8a62b17f30",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/20 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (92263 > 2048). Running this sequence through the model will result in indexing errors\n",
|
|
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.18it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"3687.8080000000004"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
import random, re

RECORD_COUNT = 20  # prompts written per entry in ROWS

ROWS = [4000]  # number of record lines in each prompt
output_dir = "."

# Style of the line identifiers: "LRT" uses 1..n, anything else is delegated
# to generate_line_index.
line_idx_opt = "LRT-NL"

# Instruction text placed in front of every record (identical for all prompts).
prompt_header = (
    "Below is a record of lines I want you to remember. "
    "Each line begins with 'line <line index>' and contains "
    "a '<REGISTER_CONTENT>' at the end of the line as a numerical value. "
    "For each line index, memorize its corresponding <REGISTER_CONTENT>. At "
    "the end of the record, I will ask you to retrieve the corresponding "
    "<REGISTER_CONTENT> of a certain line index. Now the record start:\n\n"
)

total_prompt_tokens = 0
total_records = 0

for n in ROWS:
    output_path = os.path.join(output_dir, f"{n}_lines_en.jsonl")

    # `with` guarantees the file is flushed and closed even if a record fails.
    with open(output_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(RECORD_COUNT)):
            if line_idx_opt == "LRT":
                line_idxes = list(range(1, n + 1))
            else:
                line_idxes = generate_line_index(n, line_idx_opt)

            lines = [
                f"line {idx}: REGISTER_CONTENT is <{random.randint(1, 50000)}>\n"
                for idx in line_idxes
            ]

            # random_num is the 0-based position of the line to retrieve,
            # random_idx is the identifier shown to the model for that line.
            random_num = random.randint(0, len(line_idxes) - 1)
            random_idx = line_idxes[random_num]

            # Must run before the header is inserted, while positions in
            # `lines` still match positions in `line_idxes`.
            expected_number, correct_line = retrieve_expected(lines, random_num)
            lines.insert(0, f"{prompt_header}")

            # The question is stored separately and is not part of `prompt`.
            prompt_postfix = f"\nNow the record is over. Tell me what is the <REGISTER_CONTENT> in line {random_idx}? I need the number."

            prompt = generate_prompt_from_lines(lines)
            prompt_len = len(tokenizer.encode(prompt))

            total_prompt_tokens += prompt_len
            total_records += 1

            output = {
                "random_idx": (random_idx, random_num),  # this is the line to retrieve
                "expected_number": expected_number,
                "num_lines": n,
                "prompt_len": prompt_len,
                "correct_line": correct_line,
                "prompt_postfix": prompt_postfix,
                "prompt": prompt,
            }

            json.dump(output, f, ensure_ascii=False)
            f.write("\n")

# Mean prompt length in tokens over the records actually written. The previous
# code divided by a hard-coded 500, which under-reported the mean whenever
# fewer than 500 records were generated.
avg_len = total_prompt_tokens / total_records if total_records else 0

avg_len
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b3b3d86e-6da6-44f2-887a-ae1374961fa0",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!head -n 1 {n}_lines_en.jsonl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "cf86aebf-b78d-43bf-8aea-d9ced7676855",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 20 4000_lines_en.jsonl\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!wc -l {n}_lines_en.jsonl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2dbb8bc4-8449-43d3-b32b-fe0072a815e7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|