Skip to content

Commit

Permalink
1.4.6: optimize api speed
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Dec 15, 2022
1 parent f6f21d3 commit d393957
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
23 changes: 12 additions & 11 deletions langame/functions/services.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from multiprocessing import Pool

from multiprocessing.pool import ThreadPool
import time
from typing import List, Optional, Tuple, Any
from langame.messages import (
Expand All @@ -9,8 +8,8 @@
)
import pytz
from random import choice
from firebase_admin import firestore, initialize_app
from google.cloud.firestore import DocumentSnapshot, Client
from firebase_admin import firestore
from google.cloud.firestore import DocumentSnapshot, Client, CollectionReference
from sentry_sdk import capture_exception
import logging
import datetime
Expand All @@ -21,6 +20,7 @@

def _generate(
i: int,
memes_collection_ref: CollectionReference,
api_key_doc_id: str,
logger: logging.Logger,
topics: List[str],
Expand All @@ -30,16 +30,16 @@ def _generate(
profanity_threshold: str,
translated: bool,
) -> Tuple[Optional[dict], Optional[dict], Optional[dict]]:
try:
initialize_app()
# try:
# initialize_app()
# pylint: disable=W0703
except: pass
db: Client = firestore.client()
# except: pass
# db: Client = firestore.client()
timeout = 60
start_time = time.time()
# format to human readable date time
logger.info(f"[{i}] Generating starter at {datetime.datetime.now(utc)}")
_, ref = db.collection("memes").add(
_, ref = memes_collection_ref.add(
{
"state": "to-process",
"topics": topics,
Expand All @@ -56,7 +56,7 @@ def _generate(

# poll until it's in state "processed" or "error", timeout after 1 minute
while True:
prompt_doc = db.collection("memes").document(ref.id).get()
prompt_doc = memes_collection_ref.document(ref.id).get()
data = prompt_doc.to_dict()
if data.get("state") == "processed" and data.get("content", None):
if translated and not data.get("translated", None):
Expand Down Expand Up @@ -168,13 +168,14 @@ def request_starter_for_service(
)

# generate in parallel for "limit"
with Pool(processes=limit) as pool:
with ThreadPool(processes=limit) as pool:

responses = pool.starmap(
_generate,
[
(
i,
db.collection("memes"),
api_key_doc.id,
logger,
topics,
Expand Down
1 change: 1 addition & 0 deletions run/collection/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def test_create_starter(self):
data = {
# pick 2 random topic
"topics": random.sample(fun_topics, 2),
"limit": 3,
}
print("querying with topics: ", data["topics"])
r = requests.post(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
name="langame",
packages=find_packages(),
include_package_data=True,
version="1.4.5",
version="1.4.6",
description="",
install_requires=[
"firebase_admin",
Expand Down

0 comments on commit d393957

Please sign in to comment.