From b178d75bbc54d5a96e3d55bbbdb5b2c3e60bbaa1 Mon Sep 17 00:00:00 2001 From: Zac Li Date: Mon, 15 Jan 2024 18:15:22 +0800 Subject: [PATCH 1/6] feat: executor to gcp custom container --- jina/enums.py | 1 + jina/orchestrate/deployments/__init__.py | 10 +- jina/orchestrate/flow/base.py | 12 +- jina/serve/runtimes/servers/http.py | 37 ++++ jina/serve/runtimes/worker/http_gcp_app.py | 192 ++++++++++++++++++ .../serve/runtimes/worker/request_handling.py | 22 ++ tests/integration/docarray_v2/gcp/Dockerfile | 7 + .../docarray_v2/gcp/SampleExecutor/README.md | 2 + .../docarray_v2/gcp/SampleExecutor/config.yml | 8 + .../gcp/SampleExecutor/executor.py | 34 ++++ .../gcp/SampleExecutor/requirements.txt | 0 tests/integration/docarray_v2/gcp/__init__.py | 0 tests/integration/docarray_v2/gcp/test_gcp.py | 74 +++++++ 13 files changed, 388 insertions(+), 11 deletions(-) create mode 100644 jina/serve/runtimes/worker/http_gcp_app.py create mode 100644 tests/integration/docarray_v2/gcp/Dockerfile create mode 100644 tests/integration/docarray_v2/gcp/SampleExecutor/README.md create mode 100644 tests/integration/docarray_v2/gcp/SampleExecutor/config.yml create mode 100644 tests/integration/docarray_v2/gcp/SampleExecutor/executor.py create mode 100644 tests/integration/docarray_v2/gcp/SampleExecutor/requirements.txt create mode 100644 tests/integration/docarray_v2/gcp/__init__.py create mode 100644 tests/integration/docarray_v2/gcp/test_gcp.py diff --git a/jina/enums.py b/jina/enums.py index f85d26bdc8db0..9bf8cae100ffe 100644 --- a/jina/enums.py +++ b/jina/enums.py @@ -269,6 +269,7 @@ class ProviderType(BetterEnum): NONE = 0 #: no provider SAGEMAKER = 1 #: AWS SageMaker + GCP = 2 #: GCP def replace_enum_to_str(obj): diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index 580370fd6373f..fefeb40313f23 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -386,7 +386,7 @@ def __init__( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -476,21 +476,21 @@ def __init__( args = ArgNamespace.kwargs2namespace(kwargs, parser, True) self.args = args self._gateway_load_balancer = False - if self.args.provider == ProviderType.SAGEMAKER: + if self.args.provider in (ProviderType.SAGEMAKER, ProviderType.GCP): if self._gateway_kwargs.get('port', 0) == 8080: raise ValueError( - 'Port 8080 is reserved for Sagemaker deployment. ' + 'Port 8080 is reserved for CSP deployment. ' 'Please use another port' ) if self.args.port != [8080]: warnings.warn( - 'Port is changed to 8080 for Sagemaker deployment. ' + 'Port is changed to 8080 for CSP deployment. ' f'Port {self.args.port} is ignored' ) self.args.port = [8080] if self.args.protocol != [ProtocolType.HTTP]: warnings.warn( - 'Protocol is changed to HTTP for Sagemaker deployment. ' + 'Protocol is changed to HTTP for CSP deployment. ' f'Protocol {self.args.protocol} is ignored' ) self.args.protocol = [ProtocolType.HTTP] diff --git a/jina/orchestrate/flow/base.py b/jina/orchestrate/flow/base.py index 3f93448c0015a..1364697181503 100644 --- a/jina/orchestrate/flow/base.py +++ b/jina/orchestrate/flow/base.py @@ -273,7 +273,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -464,7 +464,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -969,7 +969,7 @@ def add( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -1132,7 +1132,7 @@ def add( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -1396,7 +1396,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -1496,7 +1496,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway diff --git a/jina/serve/runtimes/servers/http.py b/jina/serve/runtimes/servers/http.py index 4ca3685c735ec..9cf1ca663d6a0 100644 --- a/jina/serve/runtimes/servers/http.py +++ b/jina/serve/runtimes/servers/http.py @@ -297,3 +297,40 @@ def app(self): cors=self.cors, logger=self.logger, ) + + +class GCPHTTPServer(FastAPIBaseServer): + """ + :class:`GCPHTTPServer` is a FastAPIBaseServer that uses a custom FastAPI app for GCP endpoints + + """ + + @property + def port(self): + """Get the port for the GCP server + :return: Return the port for the GCP server, always 8080""" + return 8080 + + @property + def ports(self): + """Get the port for the GCP server + :return: Return the port for the GCP server, always 8080""" + return [8080] + + @property + def app(self): + """Get the GCP fastapi app + :return: Return a FastAPI app for the GCP container + """ + return self._request_handler._http_fastapi_gcp_app( + title=self.title, + description=self.description, + no_crud_endpoints=self.no_crud_endpoints, + no_debug_endpoints=self.no_debug_endpoints, + expose_endpoints=self.expose_endpoints, + expose_graphql_endpoint=self.expose_graphql_endpoint, + tracing=self.tracing, + tracer_provider=self.tracer_provider, + cors=self.cors, + logger=self.logger, + ) diff --git a/jina/serve/runtimes/worker/http_gcp_app.py b/jina/serve/runtimes/worker/http_gcp_app.py new file mode 100644 index 0000000000000..a2f403525679b --- /dev/null +++ b/jina/serve/runtimes/worker/http_gcp_app.py @@ -0,0 +1,192 @@ +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union, Any + +from jina._docarray import docarray_v2 +from jina.importer import ImportExtensions +from jina.types.request.data import DataRequest + +if TYPE_CHECKING: + from jina.logging.logger import JinaLogger + +if docarray_v2: + from docarray import BaseDoc, DocList + + +def get_fastapi_app( + request_models_map: Dict, + caller: Callable, + logger: 'JinaLogger', + cors: bool = False, + **kwargs, +): + """ + Get the app from FastAPI as the REST interface. + + :param request_models_map: Map describing the endpoints and its Pydantic models + :param caller: Callable to be handled by the endpoints of the returned FastAPI app + :param logger: Logger object + :param cors: If set, a CORS middleware is added to FastAPI frontend to allow cross-origin access. + :param kwargs: Extra kwargs to make it compatible with other methods + :return: fastapi app + """ + with ImportExtensions(required=True): + import pydantic + from fastapi import FastAPI, HTTPException, Request + from fastapi.middleware.cors import CORSMiddleware + from pydantic import BaseModel, Field + from pydantic.config import BaseConfig, inherit_config + + import os + + from jina.proto import jina_pb2 + from jina.serve.runtimes.gateway.models import _to_camel_case + + if not docarray_v2: + logger.warning('Only docarray v2 is supported with Sagemaker. ') + return + + class Header(BaseModel): + request_id: Optional[str] = Field( + description='Request ID', example=os.urandom(16).hex() + ) + + class Config(BaseConfig): + alias_generator = _to_camel_case + allow_population_by_field_name = True + + class InnerConfig(BaseConfig): + alias_generator = _to_camel_case + allow_population_by_field_name = True + + class VertexAIResponse(BaseModel): + predictions: Any = Field( + description='Prediction results', + ) + + app = FastAPI() + + if cors: + app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + logger.warning('CORS is enabled. This service is accessible from any website!') + + def add_post_route( + endpoint_path, + input_model, + output_model, + input_doc_list_model=None, + ): + from docarray.base_doc.docarray_response import DocArrayResponse + + app_kwargs = dict( + path=f'/{endpoint_path.strip("/")}', + methods=['POST'], + summary=f'Endpoint {endpoint_path}', + response_model=Union[output_model, List[output_model]], + response_class=DocArrayResponse, + ) + + def is_valid_csv(content: str) -> bool: + import csv + from io import StringIO + + try: + f = StringIO(content) + reader = csv.DictReader(f) + for _ in reader: + pass + + return True + except Exception: + return False + + async def process(body) -> output_model: + req = DataRequest() + if body.header is not None: + req.header.request_id = body.header.request_id + + if body.parameters is not None: + req.parameters = body.parameters + req.header.exec_endpoint = endpoint_path + req.document_array_cls = DocList[input_doc_model] + + data = body.data + if isinstance(data, list): + req.data.docs = DocList[input_doc_list_model](data) + else: + req.data.docs = DocList[input_doc_list_model]([data]) + if body.header is None: + req.header.request_id = req.docs[0].id + + resp = await caller(req) + status = resp.header.status + + if status.code == jina_pb2.StatusProto.ERROR: + raise HTTPException(status_code=499, detail=status.description) + else: + return {"predictions": resp.docs} + return output_model(predictions=resp.docs) + + @app.api_route(**app_kwargs) + async def post(request: Request): + content_type = request.headers.get('content-type') + if content_type == 'application/json': + json_body = await request.json() + transformed_json_body = {"data": [{"text": instance} for instance in json_body["instances"]]} + return await process(input_model(**transformed_json_body)) + + elif content_type in ('text/csv', 'application/csv'): + # TODO: fix here + return await process(input_model(data=[])) + else: + raise HTTPException( + status_code=400, + detail=f'Invalid content-type: {content_type}. ' + f'Please use either application/json or text/csv.', + ) + + for endpoint, input_output_map in request_models_map.items(): + if endpoint != '_jina_dry_run_': + input_doc_model = input_output_map['input']['model'] + parameters_model = input_output_map['parameters']['model'] or Optional[Dict] + default_parameters = ( + ... if input_output_map['parameters']['model'] else None + ) + + _config = inherit_config(InnerConfig, BaseDoc.__config__) + endpoint_input_model = pydantic.create_model( + f'{endpoint.strip("/")}_input_model', + data=(Union[List[input_doc_model], input_doc_model], ...), + parameters=(parameters_model, default_parameters), + header=(Optional[Header], None), + __config__=_config, + ) + + add_post_route( + endpoint, + input_model=endpoint_input_model, + output_model=VertexAIResponse, + input_doc_list_model=input_doc_model, + ) + + from jina.serve.runtimes.gateway.health_model import JinaHealthModel + + # `/ping` route is required by AWS Sagemaker + @app.get( + path='/ping', + summary='Get the health of Jina Executor service', + response_model=JinaHealthModel, + ) + async def _executor_health(): + """ + Get the health of this Gateway service. + .. # noqa: DAR201 + + """ + return {} + + return app diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 0849aaebb388d..57f26a6767da2 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -212,6 +212,28 @@ async def _shutdown(): return app + def _http_fastapi_gcp_app(self, **kwargs): + from jina.serve.runtimes.worker.http_gcp_app import get_fastapi_app + + request_models_map = self._executor._get_endpoint_models_dict() + + def call_handle(request): + is_generator = request_models_map[request.header.exec_endpoint][ + 'is_generator' + ] + + return self.process_single_data(request, None, is_generator=is_generator) + + app = get_fastapi_app( + request_models_map=request_models_map, caller=call_handle, **kwargs + ) + + @app.on_event('shutdown') + async def _shutdown(): + await self.close() + + return app + async def _hot_reload(self): import inspect diff --git a/tests/integration/docarray_v2/gcp/Dockerfile b/tests/integration/docarray_v2/gcp/Dockerfile new file mode 100644 index 0000000000000..b04f20b347f45 --- /dev/null +++ b/tests/integration/docarray_v2/gcp/Dockerfile @@ -0,0 +1,7 @@ +FROM jinaai/jina:test-pip + +COPY . /executor_root/ + +WORKDIR /executor_root/SampleExecutor + +ENTRYPOINT ["jina", "executor", "--uses", "config.yml"] diff --git a/tests/integration/docarray_v2/gcp/SampleExecutor/README.md b/tests/integration/docarray_v2/gcp/SampleExecutor/README.md new file mode 100644 index 0000000000000..49da1225f4487 --- /dev/null +++ b/tests/integration/docarray_v2/gcp/SampleExecutor/README.md @@ -0,0 +1,2 @@ +# SampleExecutor + diff --git a/tests/integration/docarray_v2/gcp/SampleExecutor/config.yml b/tests/integration/docarray_v2/gcp/SampleExecutor/config.yml new file mode 100644 index 0000000000000..6b819858f2fc8 --- /dev/null +++ b/tests/integration/docarray_v2/gcp/SampleExecutor/config.yml @@ -0,0 +1,8 @@ +jtype: SampleExecutor +py_modules: + - executor.py +metas: + name: SampleExecutor + description: + url: + keywords: [] \ No newline at end of file diff --git a/tests/integration/docarray_v2/gcp/SampleExecutor/executor.py b/tests/integration/docarray_v2/gcp/SampleExecutor/executor.py new file mode 100644 index 0000000000000..1e0b4afc129c2 --- /dev/null +++ b/tests/integration/docarray_v2/gcp/SampleExecutor/executor.py @@ -0,0 +1,34 @@ +import numpy as np +from docarray import BaseDoc, DocList +from docarray.typing import NdArray +from pydantic import Field + +from jina import Executor, requests + + +class TextDoc(BaseDoc): + text: str = Field(description="The text of the document", default="") + + +class EmbeddingResponseModel(TextDoc): + embeddings: NdArray = Field(description="The embedding of the texts", default=[]) + + class Config(BaseDoc.Config): + allow_population_by_field_name = True + arbitrary_types_allowed = True + json_encoders = {NdArray: lambda v: v.tolist()} + + +class SampleExecutor(Executor): + @requests(on="/encode") + def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]: + ret = [] + for doc in docs: + ret.append( + EmbeddingResponseModel( + id=doc.id, + text=doc.text, + embeddings=np.random.random((1, 64)), + ) + ) + return DocList[EmbeddingResponseModel](ret) diff --git a/tests/integration/docarray_v2/gcp/SampleExecutor/requirements.txt b/tests/integration/docarray_v2/gcp/SampleExecutor/requirements.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/docarray_v2/gcp/__init__.py b/tests/integration/docarray_v2/gcp/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/docarray_v2/gcp/test_gcp.py b/tests/integration/docarray_v2/gcp/test_gcp.py new file mode 100644 index 0000000000000..2fc97a2adffda --- /dev/null +++ b/tests/integration/docarray_v2/gcp/test_gcp.py @@ -0,0 +1,74 @@ +import csv +import io +import os +import time +from contextlib import AbstractContextManager + +import pytest +import requests + +from jina import Deployment +from jina.helper import random_port +from jina.orchestrate.pods import Pod +from jina.parsers import set_pod_parser + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +gcp_port = 8080 + + +@pytest.fixture +def replica_docker_image_built(): + import docker + + client = docker.from_env() + client.images.build(path=cur_dir, tag='sampler-executor') + client.close() + yield + time.sleep(2) + client = docker.from_env() + client.containers.prune() + + +class chdir(AbstractContextManager): + def __init__(self, path): + self.path = path + self._old_cwd = [] + + def __enter__(self): + self._old_cwd.append(os.getcwd()) + os.chdir(self.path) + + def __exit__(self, *excinfo): + os.chdir(self._old_cwd.pop()) + + +def test_provider_gcp_pod_inference(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + 'config.yml', + '--provider', + 'gcp', + 'serve', # This is added by gcp + ] + ) + with Pod(args): + # Test the `GET /ping` endpoint (added by jina for gcp) + resp = requests.get(f'http://localhost:{gcp_port}/ping') + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint for inference + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f'http://localhost:{gcp_port}/invocations', + json={ + 'instances': ["hello world", "good apple"] + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json['predictions']) == 2 + print(resp_json) + From a3dfc9c4d295ff5a967823f9b1c73d7552418ef6 Mon Sep 17 00:00:00 2001 From: Jina Dev Bot Date: Mon, 15 Jan 2024 10:17:38 +0000 Subject: [PATCH 2/6] style: fix overload and cli autocomplete --- jina/serve/executors/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 8daa6616110e3..8ccf4cab76e36 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -1084,7 +1084,7 @@ def serve( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. - :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'GCP']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your From 8a48521b89432a9447edf99b38c1ba04b3bed72a Mon Sep 17 00:00:00 2001 From: Zac Li Date: Tue, 16 Jan 2024 11:38:32 +0800 Subject: [PATCH 3/6] fix: issue in gcp app --- jina/serve/executors/__init__.py | 8 +++---- jina/serve/runtimes/asyncio.py | 17 ++++++++++++++ jina/serve/runtimes/worker/http_gcp_app.py | 5 ++--- .../serve/runtimes/worker/request_handling.py | 2 +- tests/integration/docarray_v2/gcp/test_gcp.py | 22 ++++++++++++++++++- 5 files changed, 45 insertions(+), 9 deletions(-) diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 8ccf4cab76e36..d438746a4ab2a 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -393,7 +393,7 @@ def __init__( self._add_dynamic_batching(dynamic_batching) self._add_runtime_args(runtime_args) self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args)) - self._validate_sagemaker() + self._validate_csp() self._init_instrumentation(runtime_args) self._init_monitoring() self._init_workspace = workspace @@ -599,14 +599,14 @@ def _add_requests(self, _requests: Optional[Dict]): f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}' ) - def _validate_sagemaker(self): - # sagemaker expects the POST /invocations endpoint to be defined. + def _validate_csp(self): + # csp (sagemaker/azure/gcp) expects the POST /invocations endpoint to be defined. # if it is not defined, we check if there is only one endpoint defined, # and if so, we use it as the POST /invocations endpoint, or raise an error if ( not hasattr(self, 'runtime_args') or not hasattr(self.runtime_args, 'provider') - or self.runtime_args.provider != ProviderType.SAGEMAKER.value + or self.runtime_args.provider not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value) ): return diff --git a/jina/serve/runtimes/asyncio.py b/jina/serve/runtimes/asyncio.py index 8d2fc8beeb8bc..e53f55a5c124d 100644 --- a/jina/serve/runtimes/asyncio.py +++ b/jina/serve/runtimes/asyncio.py @@ -206,6 +206,23 @@ def _get_server(self): cors=getattr(self.args, 'cors', None), is_cancel=self.is_cancel, ) + elif ( + hasattr(self.args, 'provider') + and self.args.provider == ProviderType.GCP + ): + from jina.serve.runtimes.servers.http import GCPHTTPServer + + return GCPHTTPServer( + name=self.args.name, + runtime_args=self.args, + req_handler_cls=self.req_handler_cls, + proxy=getattr(self.args, 'proxy', None), + uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), + ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), + ssl_certfile=getattr(self.args, 'ssl_certfile', None), + cors=getattr(self.args, 'cors', None), + is_cancel=self.is_cancel, + ) elif not hasattr(self.args, 'protocol') or ( len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC ): diff --git a/jina/serve/runtimes/worker/http_gcp_app.py b/jina/serve/runtimes/worker/http_gcp_app.py index a2f403525679b..37af7897a029c 100644 --- a/jina/serve/runtimes/worker/http_gcp_app.py +++ b/jina/serve/runtimes/worker/http_gcp_app.py @@ -41,7 +41,7 @@ def get_fastapi_app( from jina.serve.runtimes.gateway.models import _to_camel_case if not docarray_v2: - logger.warning('Only docarray v2 is supported with Sagemaker. ') + logger.warning('Only docarray v2 is supported with GCP. ') return class Header(BaseModel): @@ -129,7 +129,6 @@ async def process(body) -> output_model: raise HTTPException(status_code=499, detail=status.description) else: return {"predictions": resp.docs} - return output_model(predictions=resp.docs) @app.api_route(**app_kwargs) async def post(request: Request): @@ -175,7 +174,7 @@ async def post(request: Request): from jina.serve.runtimes.gateway.health_model import JinaHealthModel - # `/ping` route is required by AWS Sagemaker + # `/ping` route is required by GCP @app.get( path='/ping', summary='Get the health of Jina Executor service', diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 57f26a6767da2..f5c74445bf291 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -326,7 +326,7 @@ def _init_monitoring( if metrics_registry: with ImportExtensions( required=True, - help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', + help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina', ): from prometheus_client import Counter, Summary diff --git a/tests/integration/docarray_v2/gcp/test_gcp.py b/tests/integration/docarray_v2/gcp/test_gcp.py index 2fc97a2adffda..9539e8317ea49 100644 --- a/tests/integration/docarray_v2/gcp/test_gcp.py +++ b/tests/integration/docarray_v2/gcp/test_gcp.py @@ -70,5 +70,25 @@ def test_provider_gcp_pod_inference(): assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json['predictions']) == 2 - print(resp_json) + +def test_provider_gcp_deployment_inference(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + dep_port = random_port() + with Deployment(uses='config.yml', provider='gcp', port=dep_port): + # Test the `GET /ping` endpoint (added by jina for gcp) + resp = requests.get(f'http://localhost:{dep_port}/ping') + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f'http://localhost:{dep_port}/invocations', + json={ + 'instances': ["hello world", "good apple"] + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json['predictions']) == 2 From 17d5f90bf8ef8b88d6465207f2033ec3c48c06d4 Mon Sep 17 00:00:00 2001 From: Zac Li Date: Wed, 17 Jan 2024 13:54:26 +0800 Subject: [PATCH 4/6] fix: issue in gcp app --- jina/serve/runtimes/worker/http_gcp_app.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/jina/serve/runtimes/worker/http_gcp_app.py b/jina/serve/runtimes/worker/http_gcp_app.py index 37af7897a029c..2fbe68fd67692 100644 --- a/jina/serve/runtimes/worker/http_gcp_app.py +++ b/jina/serve/runtimes/worker/http_gcp_app.py @@ -79,6 +79,7 @@ def add_post_route( input_model, output_model, input_doc_list_model=None, + output_doc_list_model=None, ): from docarray.base_doc.docarray_response import DocArrayResponse @@ -128,7 +129,7 @@ async def process(body) -> output_model: if status.code == jina_pb2.StatusProto.ERROR: raise HTTPException(status_code=499, detail=status.description) else: - return {"predictions": resp.docs} + return VertexAIResponse(predictions=output_model(data=resp.docs, parameters=resp.parameters)) @app.api_route(**app_kwargs) async def post(request: Request): @@ -151,6 +152,7 @@ async def post(request: Request): for endpoint, input_output_map in request_models_map.items(): if endpoint != '_jina_dry_run_': input_doc_model = input_output_map['input']['model'] + output_doc_model = input_output_map['output']['model'] parameters_model = input_output_map['parameters']['model'] or Optional[Dict] default_parameters = ( ... if input_output_map['parameters']['model'] else None @@ -165,11 +167,19 @@ async def post(request: Request): __config__=_config, ) + endpoint_output_model = pydantic.create_model( + f'{endpoint.strip("/")}_output_model', + data=(Union[List[output_doc_model], output_doc_model], ...), + parameters=(Optional[Dict], None), + __config__=_config, + ) + add_post_route( endpoint, input_model=endpoint_input_model, - output_model=VertexAIResponse, + output_model=endpoint_output_model, input_doc_list_model=input_doc_model, + output_doc_list_model=VertexAIResponse, ) from jina.serve.runtimes.gateway.health_model import JinaHealthModel From 84a466715044e401c753c21486f85001bc52cb97 Mon Sep 17 00:00:00 2001 From: Jina Dev Bot Date: Wed, 17 Jan 2024 05:57:57 +0000 Subject: [PATCH 5/6] style: fix overload and cli autocomplete --- jina/serve/executors/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index d438746a4ab2a..24ed4c32390b6 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -606,7 +606,8 @@ def _validate_csp(self): if ( not hasattr(self, 'runtime_args') or not hasattr(self.runtime_args, 'provider') - or self.runtime_args.provider not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value) + or self.runtime_args.provider + not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value) ): return From b258808cffbb83e764f3fc4e5eb609c2ab16a39d Mon Sep 17 00:00:00 2001 From: Zac Li Date: Thu, 1 Feb 2024 14:53:42 +0800 Subject: [PATCH 6/6] fix: issue in gcp app --- jina/serve/runtimes/worker/http_gcp_app.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jina/serve/runtimes/worker/http_gcp_app.py b/jina/serve/runtimes/worker/http_gcp_app.py index 2fbe68fd67692..c804acda8d4a4 100644 --- a/jina/serve/runtimes/worker/http_gcp_app.py +++ b/jina/serve/runtimes/worker/http_gcp_app.py @@ -79,7 +79,6 @@ def add_post_route( input_model, output_model, input_doc_list_model=None, - output_doc_list_model=None, ): from docarray.base_doc.docarray_response import DocArrayResponse @@ -87,7 +86,7 @@ def add_post_route( path=f'/{endpoint_path.strip("/")}', methods=['POST'], summary=f'Endpoint {endpoint_path}', - response_model=Union[output_model, List[output_model]], + response_model=VertexAIResponse, response_class=DocArrayResponse, ) @@ -129,7 +128,7 @@ async def process(body) -> output_model: if status.code == jina_pb2.StatusProto.ERROR: raise HTTPException(status_code=499, detail=status.description) else: - return VertexAIResponse(predictions=output_model(data=resp.docs, parameters=resp.parameters)) + return VertexAIResponse(predictions=req.data.docs) @app.api_route(**app_kwargs) async def post(request: Request): @@ -140,7 +139,7 @@ async def post(request: Request): return await process(input_model(**transformed_json_body)) elif content_type in ('text/csv', 'application/csv'): - # TODO: fix here + # TODO: fix here for batch transform return await process(input_model(data=[])) else: raise HTTPException( @@ -179,7 +178,6 @@ async def post(request: Request): input_model=endpoint_input_model, output_model=endpoint_output_model, input_doc_list_model=input_doc_model, - output_doc_list_model=VertexAIResponse, ) from jina.serve.runtimes.gateway.health_model import JinaHealthModel