Skip to content

Commit

Permalink
Merge branch 'main' into openai-legacy-response
Browse files Browse the repository at this point in the history
  • Loading branch information
lrafeei authored May 10, 2024
2 parents b2db1c2 + 3e5be52 commit b4e1de2
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 26 deletions.
5 changes: 5 additions & 0 deletions newrelic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4374,6 +4374,11 @@ def _process_module_builtin_defaults():
"newrelic.hooks.application_celery",
"instrument_celery_app_task",
)
_process_module_definition(
"celery.app.trace",
"newrelic.hooks.application_celery",
"instrument_celery_app_trace",
)
_process_module_definition("celery.worker", "newrelic.hooks.application_celery", "instrument_celery_worker")
_process_module_definition(
"celery.concurrency.processes",
Expand Down
51 changes: 41 additions & 10 deletions newrelic/hooks/application_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@
from newrelic.api.message_trace import MessageTrace
from newrelic.api.pre_function import wrap_pre_function
from newrelic.api.transaction import current_transaction
from newrelic.common.object_wrapper import FunctionWrapper, wrap_function_wrapper
from newrelic.common.object_wrapper import FunctionWrapper, wrap_function_wrapper, _NRBoundFunctionWrapper
from newrelic.core.agent import shutdown_agent

UNKNOWN_TASK_NAME = "<Unknown Task>"
MAPPING_TASK_NAMES = {"celery.starmap", "celery.map"}


def task_name(*args, **kwargs):
def task_info(instance, *args, **kwargs):
# Grab the current task, which can be located in either place
if args:
if instance:
task = instance
elif args:
task = args[0]
elif "task" in kwargs:
task = kwargs["task"]
Expand All @@ -46,27 +48,27 @@ def task_name(*args, **kwargs):

# Task can be either a task instance or a signature, which subclasses dict, or an actual dict in some cases.
task_name = getattr(task, "name", None) or task.get("task", UNKNOWN_TASK_NAME)
task_source = task

# Under mapping tasks, the root task name isn't descriptive enough so we append the
# subtask name to differentiate between different mapping tasks
if task_name in MAPPING_TASK_NAMES:
try:
subtask = kwargs["task"]["task"]
task_name = "/".join((task_name, subtask))
task_source = task.app._tasks[subtask]
except Exception:
pass

return task_name
return task_name, task_source


def CeleryTaskWrapper(wrapped):
def wrapper(wrapped, instance, args, kwargs):
transaction = current_transaction(active_only=False)

if instance is not None:
_name = task_name(instance, *args, **kwargs)
else:
_name = task_name(*args, **kwargs)
# Grab task name and source
_name, _source = task_info(instance, *args, **kwargs)

# A Celery Task can be called either outside of a transaction, or
# within the context of an existing transaction. There are 3
Expand All @@ -93,11 +95,11 @@ def wrapper(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)

elif transaction:
with FunctionTrace(_name, source=instance):
with FunctionTrace(_name, source=_source):
return wrapped(*args, **kwargs)

else:
with BackgroundTask(application_instance(), _name, "Celery", source=instance) as transaction:
with BackgroundTask(application_instance(), _name, "Celery", source=_source) as transaction:
# Attempt to grab distributed tracing headers
try:
# Headers on earlier versions of Celery may end up as attributes
Expand Down Expand Up @@ -200,6 +202,26 @@ def wrap_Celery_send_task(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)


def wrap_worker_optimizations(wrapped, instance, args, kwargs):
# Attempt to uninstrument BaseTask before stack protection is installed or uninstalled
try:
from celery.app.task import BaseTask

if isinstance(BaseTask.__call__, _NRBoundFunctionWrapper):
BaseTask.__call__ = BaseTask.__call__.__wrapped__
except Exception:
BaseTask = None

# Allow metaprogramming to run
result = wrapped(*args, **kwargs)

# Rewrap finalized BaseTask
if BaseTask: # Ensure imports succeeded
BaseTask.__call__ = CeleryTaskWrapper(BaseTask.__call__)

return result


def instrument_celery_app_base(module):
if hasattr(module, "Celery") and hasattr(module.Celery, "send_task"):
wrap_function_wrapper(module, "Celery.send_task", wrap_Celery_send_task)
Expand Down Expand Up @@ -239,3 +261,12 @@ def force_agent_shutdown(*args, **kwargs):

if hasattr(module, "Worker"):
wrap_pre_function(module, "Worker._do_exit", force_agent_shutdown)


def instrument_celery_app_trace(module):
# Uses same wrapper for setup and reset worker optimizations to prevent patching and unpatching from removing wrappers
if hasattr(module, "setup_worker_optimizations"):
wrap_function_wrapper(module, "setup_worker_optimizations", wrap_worker_optimizations)

if hasattr(module, "reset_worker_optimizations"):
wrap_function_wrapper(module, "reset_worker_optimizations", wrap_worker_optimizations)
49 changes: 35 additions & 14 deletions tests/application_celery/test_task_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from _target_application import add, tsum
from celery import chain, chord, group
from testing_support.validators.validate_code_level_metrics import (
validate_code_level_metrics,
)
from testing_support.validators.validate_transaction_count import (
validate_transaction_count,
)
from testing_support.validators.validate_transaction_metrics import (
validate_transaction_metrics,
)

FORGONE_TASK_METRICS = [("Function/_target_application.add", None), ("Function/_target_application.tsum", None)]
import celery


def test_task_wrapping_detection():
"""
Ensure celery detects our monkeypatching properly and will run our instrumentation
on __call__ and runs that instead of micro-optimizing it away to a run() call.
FORGONE_TASK_METRICS = [("Function/_target_application.add", None), ("Function/_target_application.tsum", None)]

If this is not working, most other tests in this file will fail as the different ways
of running celery tasks will not all run our instrumentation.
"""
from celery.app.trace import task_has_custom

assert task_has_custom(add, "__call__")
@pytest.fixture(scope="module", autouse=True, params=[False, True], ids=["unpatched", "patched"])
def with_worker_optimizations(request, celery_worker_available):
if request.param:
celery.app.trace.setup_worker_optimizations(celery_worker_available.app)

yield request.param
celery.app.trace.reset_worker_optimizations()


@validate_transaction_metrics(
Expand All @@ -45,6 +47,7 @@ def test_task_wrapping_detection():
rollup_metrics=FORGONE_TASK_METRICS,
background_task=True,
)
@validate_code_level_metrics("_target_application", "add")
@validate_transaction_count(1)
def test_celery_task_call():
"""
Expand All @@ -61,6 +64,7 @@ def test_celery_task_call():
rollup_metrics=FORGONE_TASK_METRICS,
background_task=True,
)
@validate_code_level_metrics("_target_application", "add")
@validate_transaction_count(1)
def test_celery_task_apply():
"""
Expand All @@ -78,6 +82,7 @@ def test_celery_task_apply():
rollup_metrics=FORGONE_TASK_METRICS,
background_task=True,
)
@validate_code_level_metrics("_target_application", "add")
@validate_transaction_count(1)
def test_celery_task_delay():
"""
Expand All @@ -95,6 +100,7 @@ def test_celery_task_delay():
rollup_metrics=FORGONE_TASK_METRICS,
background_task=True,
)
@validate_code_level_metrics("_target_application", "add")
@validate_transaction_count(1)
def test_celery_task_apply_async():
"""
Expand All @@ -112,6 +118,7 @@ def test_celery_task_apply_async():
rollup_metrics=FORGONE_TASK_METRICS,
background_task=True,
)
@validate_code_level_metrics("_target_application", "add")
@validate_transaction_count(1)
def test_celery_app_send_task(celery_session_app):
"""
Expand All @@ -129,6 +136,7 @@ def test_celery_app_send_task(celery_session_app):
rollup_metrics=FORGONE_TASK_METRICS,
background_task=True,
)
@validate_code_level_metrics("_target_application", "add")
@validate_transaction_count(1)
def test_celery_task_signature():
"""
Expand All @@ -154,6 +162,8 @@ def test_celery_task_signature():
background_task=True,
index=-2,
)
@validate_code_level_metrics("_target_application", "add")
@validate_code_level_metrics("_target_application", "add", index=-2)
@validate_transaction_count(2)
def test_celery_task_link():
"""
Expand All @@ -179,12 +189,14 @@ def test_celery_task_link():
background_task=True,
index=-2,
)
@validate_code_level_metrics("_target_application", "add")
@validate_code_level_metrics("_target_application", "add", index=-2)
@validate_transaction_count(2)
def test_celery_chain():
"""
Executes multiple tasks on worker process and returns an AsyncResult.
"""
result = chain(add.s(3, 4), add.s(5))()
result = celery.chain(add.s(3, 4), add.s(5))()

result = result.get()
assert result == 12
Expand All @@ -205,12 +217,14 @@ def test_celery_chain():
background_task=True,
index=-2,
)
@validate_code_level_metrics("_target_application", "add")
@validate_code_level_metrics("_target_application", "add", index=-2)
@validate_transaction_count(2)
def test_celery_group():
"""
Executes multiple tasks on worker process and returns an AsyncResult.
"""
result = group(add.s(3, 4), add.s(1, 2))()
result = celery.group(add.s(3, 4), add.s(1, 2))()
result = result.get()
assert result == [7, 3]

Expand Down Expand Up @@ -238,12 +252,15 @@ def test_celery_group():
background_task=True,
index=-3,
)
@validate_code_level_metrics("_target_application", "tsum")
@validate_code_level_metrics("_target_application", "add", index=-2)
@validate_code_level_metrics("_target_application", "add", index=-3)
@validate_transaction_count(3)
def test_celery_chord():
"""
Executes 2 add tasks, followed by a tsum task on the worker process and returns an AsyncResult.
"""
result = chord([add.s(3, 4), add.s(1, 2)])(tsum.s())
result = celery.chord([add.s(3, 4), add.s(1, 2)])(tsum.s())
result = result.get()
assert result == 10

Expand All @@ -255,6 +272,7 @@ def test_celery_chord():
rollup_metrics=[("Function/_target_application.tsum", 2)],
background_task=True,
)
@validate_code_level_metrics("_target_application", "tsum", count=3)
@validate_transaction_count(1)
def test_celery_task_map():
"""
Expand All @@ -272,6 +290,7 @@ def test_celery_task_map():
rollup_metrics=[("Function/_target_application.add", 2)],
background_task=True,
)
@validate_code_level_metrics("_target_application", "add", count=3)
@validate_transaction_count(1)
def test_celery_task_starmap():
"""
Expand All @@ -297,6 +316,8 @@ def test_celery_task_starmap():
background_task=True,
index=-2,
)
@validate_code_level_metrics("_target_application", "add", count=2)
@validate_code_level_metrics("_target_application", "add", count=2, index=-2)
@validate_transaction_count(2)
def test_celery_task_chunks():
"""
Expand Down
46 changes: 46 additions & 0 deletions tests/application_celery/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from _target_application import add

import celery

from newrelic.common.object_wrapper import _NRBoundFunctionWrapper


FORGONE_TASK_METRICS = [("Function/_target_application.add", None), ("Function/_target_application.tsum", None)]


def test_task_wrapping_detection():
"""
Ensure celery detects our monkeypatching properly and will run our instrumentation
on __call__ and runs that instead of micro-optimizing it away to a run() call.
If this is not working, most other tests in this file will fail as the different ways
of running celery tasks will not all run our instrumentation.
"""
assert celery.app.trace.task_has_custom(add, "__call__")


def test_worker_optimizations_preserve_instrumentation(celery_worker_available):
is_instrumented = lambda: isinstance(celery.app.task.BaseTask.__call__, _NRBoundFunctionWrapper)

celery.app.trace.reset_worker_optimizations()
assert is_instrumented(), "Instrumentation not initially applied."

celery.app.trace.setup_worker_optimizations(celery_worker_available.app)
assert is_instrumented(), "setup_worker_optimizations removed instrumentation."

celery.app.trace.reset_worker_optimizations()
assert is_instrumented(), "reset_worker_optimizations removed instrumentation."
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@


def validate_transaction_count(count):
_transactions = []
transactions = []

@transient_function_wrapper('newrelic.core.stats_engine',
'StatsEngine.record_transaction')
def _increment_count(wrapped, instance, args, kwargs):
_transactions.append(getattr(args[0], "name", True))
transactions.append(getattr(args[0], "name", True))
return wrapped(*args, **kwargs)

@function_wrapper
def _validate_transaction_count(wrapped, instance, args, kwargs):
_new_wrapped = _increment_count(wrapped)
result = _new_wrapped(*args, **kwargs)

_transactions = list(transactions)
del transactions[:] # Clear list for subsequent test runs

assert count == len(_transactions), (count, len(_transactions), _transactions)

return result
Expand Down

0 comments on commit b4e1de2

Please sign in to comment.