Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update A_B.ipynb #161

Open
wants to merge 3 commits into
base: dmarx.ab_notebook
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 103 additions & 89 deletions nbs/A_B.ipynb
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "i0RWQRJAkdJe"
},
"source": [
"# Stability.AI A/B Testing Notebook\n",
"\n",
Expand All @@ -38,10 +27,7 @@
" - Click again to deselect if you cahnge your mind\n",
" - The notebook does not currently constrain the user to only select one option, but that's how we recommend you use it. \n",
" - When you're satisfied with your selection, execute the cell again to log your feedback and generate a new set of images."
],
"metadata": {
"id": "i0RWQRJAkdJe"
}
]
},
{
"cell_type": "code",
Expand Down Expand Up @@ -242,12 +228,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GX85BLyFrGlJ"
},
"outputs": [],
"source": [
"%%writefile test_config.yaml\n",
"\n",
"### settings that will be used across test cases.\n",
"defaults:\n",
" grpc_host: grpc.stability.ai:443\n",
" host: grpc.stability.ai:443\n",
" # If API key not provided in test_config.yaml, user prompted with getpass\n",
" key:\n",
"\n",
Expand All @@ -270,6 +261,7 @@
"# not a fan of this name. maybe call this section \"experiments\"?\n",
"differentiators:\n",
" test_case_A:\n",
" host: grpc-staging.stability.ai:443\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure we don't want this endpoint to be the notebook default, since this notebook will be public and I'm pretty sure this endpoint isn't supposed to be. for cases where we want to test against this endpoint, we can distribute an appropriate config file that has it as a test case like this

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. Very thoughtful. Thanks!

" engine: stable-diffusion-512-v2-0\n",
" prompt_chunks:\n",
" middle: ''\n",
Expand All @@ -285,25 +277,29 @@
" middle: ''\n",
" # Don't do this, results in `middle:\"None\"`\n",
" # middle:\n"
],
"metadata": {
"id": "GX85BLyFrGlJ"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xl_t3CjZlWvq"
},
"outputs": [],
"source": [
"# @markdown ## Load Experiments\n",
"\n",
"from omegaconf import OmegaConf\n",
"import getpass\n",
"from stability_sdk import client\n",
"\n",
"import shutil\n",
"import panel as pn\n",
"pn.extension()\n",
"\n",
"!mkdir fav\n",
"!mkdir results\n",
"\n",
"\n",
"exp_cfg_fpath_out = Path(workspace_cfg.project_root) / workspace_cfg.exp_cfg_fname\n",
"exp_cfg_fpath = exp_cfg_fpath_out\n",
Expand Down Expand Up @@ -342,7 +338,7 @@
"#####################################\n",
"\n",
"required_attributes = [\n",
" 'grpc_host',\n",
" 'host',\n",
" #'api_key'\n",
" 'key',\n",
"]\n",
Expand Down Expand Up @@ -379,76 +375,35 @@
"for test_case in cfg.differentiators:\n",
" running_score[test_case]+=0\n",
"\n"
],
"metadata": {
"id": "xl_t3CjZlWvq",
"cellView": "form"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ogPWQhWm7BXN"
},
"outputs": [],
"source": [
"# @markdown # Load a random sample to score preference\n",
"\n",
"SAMPLE_IDX += 1\n",
"\n",
"Copy_to_path = \"/content/results/\"\n",
"##########################\n",
"# Log experiment outcome #\n",
"##########################\n",
"\n",
"save_images = False # @param {type:'boolean'}\n",
"save_favorite_only = False # @param {type:'boolean'}\n",
"save_favorite_only = True # @param {type:'boolean'}\n",
"\n",
"\n",
"\n",
"# to do: make this not a closure.\n",
"def log_items(items):\n",
" #for (img, test_case, kwargs_gen, is_preference) in items: # to do: dictify\n",
" recs = []\n",
" for item in items:\n",
" # assign image a filename\n",
" img_fname = f\"{RANDOM_PREFIX}_{SAMPLE_IDX}_{item['test_case']}.png\"\n",
" #rec = copy.deepcopy(item)\n",
" rec = item\n",
" img_fpath = Path(workspace_cfg.project_root) / img_fname\n",
" # save image\n",
" img = rec.pop('img')\n",
" save_im = False\n",
" if save_images or save_favorite_only:\n",
" save_im = True\n",
" if save_favorite_only and not rec['is_preference']:\n",
" save_im = False\n",
" if save_im:\n",
" print(img_fpath)\n",
" rec['img_fpath'] = str(img_fpath)\n",
" img.save(img_fpath)\n",
" # update outcome\n",
" rec['is_preference'] = rec['button'].value\n",
" if rec['is_preference']:\n",
" running_score[rec['test_case']] += 1\n",
" rec.pop('button')\n",
" # log outcome\n",
" recs.append(rec)\n",
" outfile = Path(workspace_cfg.project_root) / explog_fname\n",
" #with open(outfile, 'a') as f:\n",
" with outfile.open('a') as f:\n",
" json.dump(recs, f)\n",
" f.write('\\n')\n",
" logger.debug(running_score)\n",
" \n",
"if items:\n",
" try:\n",
" log_items(items)\n",
" posterior_plot(running_score)\n",
" except KeyError:\n",
" # fuck it\n",
" pass\n",
"\n",
"\n",
"SEED = random.randrange(0, 4294967295)\n",
"\n",
"blind_test = False # @param {type: \"boolean\"}\n",
"blind_test = True # @param {type: \"boolean\"}\n",
"\n",
"def item_to_ux(\n",
" item\n",
Expand All @@ -469,10 +424,11 @@
" output.append(toggle)\n",
" item['button'] = toggle\n",
" item['is_preference'] = toggle.value\n",
" print(item)\n",
" return pn.Column(*output)\n",
"\n",
"\n",
"non_generation_arguments = ['grpc_host', 'engine', 'key']\n",
"non_generation_arguments = ['host', 'engine', 'key']\n",
"\n",
"rec = random.choice(experiments)\n",
"\n",
Expand Down Expand Up @@ -508,6 +464,10 @@
" if artifact.type == generation.ARTIFACT_IMAGE:\n",
" img = Image.open(io.BytesIO(artifact.binary))\n",
" img = img.resize([512, 512])\n",
" if save_images or save_favorite_only:\n",
" img_fname = f\"{kwargs_gen['prompt']}_{kwargs_gen['seed']}_{SAMPLE_IDX}_{test_case}.png\"\n",
" img.save(Copy_to_path+img_fname)\n",
"\n",
" items.append({\n",
" 'img':img,\n",
" 'test_case':test_case,\n",
Expand All @@ -517,18 +477,72 @@
" 'SDK_VERSION':SDK_VERSION,\n",
" 'timestamp':time.time(),\n",
" 'user_id': workspace_cfg.notebook_user_id,\n",
" 'project_name':workspace_cfg.active_project,\n",
" 'project_name':workspace_cfg.active_project, \n",
" })\n",
"\n",
"random.shuffle(items)\n",
"pn.Row(*[item_to_ux(it) for it in items])"
],
"metadata": {
"id": "ogPWQhWm7BXN",
"cellView": "form"
},
"pn.Row(*[item_to_ux(it) for it in items]) "
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {},
"outputs": [],
"source": [
"# to do: make this not a closure.\n",
"import shutil\n",
"def log_items(items):\n",
" #for (img, test_case, kwargs_gen, is_preference) in items: # to do: dictify\n",
" recs = []\n",
" for item in items:\n",
" # assign image a filename\n",
" # rec = copy.deepcopy(item)\n",
" rec = item\n",
" # save image\n",
" rec['is_preference'] = rec['button'].value\n",
" if save_favorite_only:\n",
" if rec['is_preference']:\n",
" img_fname1 = f\"{rec['kwargs']['prompt']}_{rec['kwargs']['seed']}_{SAMPLE_IDX}_{rec['test_case']}.png\"\n",
" print(type(Copy_to_path+img_fname1))\n",
" shutil.move(Copy_to_path+img_fname1, '/content/fav')\n",
" # update outcome\n",
" if rec['is_preference']:\n",
" running_score[rec['test_case']] += 1\n",
" rec.pop('img')\n",
" rec.pop('button')\n",
" # log outcome\n",
" recs.append(rec)\n",
" outfile = Path(workspace_cfg.project_root) / explog_fname\n",
" #with open(outfile, 'a') as f:\n",
" with outfile.open('a') as f:\n",
" json.dump(recs, f)\n",
" f.write('\\n')\n",
" logger.debug(running_score)\n",
"\n",
"# print(items)\n",
"if items:\n",
" try:\n",
" log_items(items)\n",
" posterior_plot(running_score)\n",
" except KeyError:\n",
" # fuck it\n",
" pass"
]
}
]
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}