Style
AlexKoff88 committed Oct 18, 2024
1 parent 0bf2325 commit b6220d5
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions notebooks/openvino/sentence_transformer_quantization.ipynb
@@ -177,9 +177,11 @@
 "\n",
 "quantizer = OVQuantizer.from_pretrained(model)\n",
 "\n",
+"\n",
 "def preprocess_function(examples, tokenizer):\n",
 " return tokenizer(examples[\"sentence\"], padding=\"max_length\", max_length=384, truncation=True)\n",
 "\n",
+"\n",
 "calibration_dataset = quantizer.get_calibration_dataset(\n",
 " \"glue\",\n",
 " dataset_config_name=\"sst2\",\n",
@@ -194,13 +196,6 @@
 "tokenizer.save_pretrained(int8_ptq_model_path)"
 ]
 },
-{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -216,11 +211,12 @@
 "source": [
 "from transformers import Pipeline\n",
 "import torch.nn.functional as F\n",
-"import torch \n",
+"import torch\n",
 "\n",
+"\n",
 "# copied from the model card\n",
 "def mean_pooling(model_output, attention_mask):\n",
-" token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
+" token_embeddings = model_output[0] # First element of model_output contains all token embeddings\n",
 " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
 " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
 "\n",
@@ -230,7 +226,7 @@
 " # we don't have any hyperparameters to sanitize\n",
 " preprocess_kwargs = {}\n",
 " return preprocess_kwargs, {}, {}\n",
-" \n",
+"\n",
 " def preprocess(self, inputs):\n",
 " encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors=\"pt\")\n",
 " return encoded_inputs\n",
@@ -283,7 +279,7 @@
 "from datasets import load_dataset\n",
 "from evaluate import load\n",
 "\n",
-"eval_dataset = load_dataset(\"glue\",\"stsb\",split=\"validation\")\n",
+"eval_dataset = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
 "metric = load(\"glue\", \"stsb\")"
 ]
 },
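For reference, the glue/stsb metric loaded here reports Pearson and Spearman correlations. A minimal call with toy values (outputs approximate):

from evaluate import load

metric = load("glue", "stsb")
# toy predictions/references, for illustration only
print(metric.compute(predictions=[0.1, 0.9, 0.4], references=[0.0, 1.0, 0.5]))
# e.g. {'pearson': ~0.99, 'spearmanr': 1.0}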
@@ -315,21 +311,22 @@
 }
 ],
 "source": [
-"def compute_sentence_similarity(sentence_1, sentence_2,pipeline):\n",
+"def compute_sentence_similarity(sentence_1, sentence_2, pipeline):\n",
 " embedding_1 = pipeline(sentence_1)\n",
 " embedding_2 = pipeline(sentence_2)\n",
 " # compute cosine similarity between two sentences\n",
 " return torch.nn.functional.cosine_similarity(embedding_1, embedding_2, dim=1)\n",
 "\n",
+"\n",
 "def evaluate_stsb(example):\n",
-" default = compute_sentence_similarity(example[\"sentence1\"], example[\"sentence2\"], vanilla_emb)\n",
-" quantized = compute_sentence_similarity(example[\"sentence1\"], example[\"sentence2\"], q8_emb)\n",
-" return {\n",
-" \"reference\": (example[\"label\"] - 1) / (5 - 1), # rescale to [0,1]\n",
-" \"default\": float(default),\n",
-" \"quantized\": float(quantized),\n",
-" }\n",
+" default = compute_sentence_similarity(example[\"sentence1\"], example[\"sentence2\"], vanilla_emb)\n",
+" quantized = compute_sentence_similarity(example[\"sentence1\"], example[\"sentence2\"], q8_emb)\n",
+" return {\n",
+" \"reference\": (example[\"label\"] - 1) / (5 - 1), # rescale to [0,1]\n",
+" \"default\": float(default),\n",
+" \"quantized\": float(quantized),\n",
+" }\n",
 "\n",
 "\n",
 "result = eval_dataset.map(evaluate_stsb)"
 ]
@@ -353,9 +350,13 @@
 "default_acc = metric.compute(predictions=result[\"default\"], references=result[\"reference\"])\n",
 "quantized = metric.compute(predictions=result[\"quantized\"], references=result[\"reference\"])\n",
 "\n",
-"print(\"vanilla model: pearson=\", default_acc['pearson'])\n",
-"print(\"quantized model: pearson=\", quantized['pearson'])\n",
-"print(\"The quantized model achieves \", round(quantized[\"pearson\"]/default_acc[\"pearson\"],2)*100, \"% accuracy of the fp32 model\")"
+"print(\"vanilla model: pearson=\", default_acc[\"pearson\"])\n",
+"print(\"quantized model: pearson=\", quantized[\"pearson\"])\n",
+"print(\n",
+" \"The quantized model achieves \",\n",
+" round(quantized[\"pearson\"] / default_acc[\"pearson\"], 2) * 100,\n",
+" \"% accuracy of the fp32 model\",\n",
+")"
 ]
 },
 {
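One subtlety in the final print of the last hunk: round(quantized / default, 2) * 100 rounds the ratio before scaling, so the reported percentage only has whole-percent granularity. A small illustration with made-up Pearson scores:

# made-up Pearson scores, for illustration only
fp32_pearson = 0.89
int8_pearson = 0.885

ratio = int8_pearson / fp32_pearson  # ~0.9944
print(round(ratio, 2) * 100)   # ~99.0: the ratio is rounded to two decimals first
print(round(ratio * 100, 2))   # 99.44: rounding after scaling keeps more precision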
