llvm: _comp_cached: handle weakref proxy/ref caching #2674

Open
wants to merge 1 commit into
base: devel
9 changes: 8 additions & 1 deletion psyneulink/core/llvm/builder_context.py
@@ -23,7 +23,7 @@

from psyneulink.core.scheduling.time import Time, TimeScale
from psyneulink.core.globals.sampleiterator import SampleIterator
-from psyneulink.core.globals.utilities import ContentAddressableList
+from psyneulink.core.globals.utilities import ContentAddressableList, unproxy_weakproxy
from psyneulink.core import llvm as pnlvm

from . import codegen
@@ -75,6 +75,13 @@ def _gen_llvm_function(self, *, ctx, tags:frozenset):
 def _comp_cached(func):
     @functools.wraps(func)
     def wrapper(bctx, obj):
+        if isinstance(obj, weakref.ProxyTypes):
+            # only call for ProxyTypes because this won't fail on most
+            # objects, but specifically not on 'super()' referenced
+            # below, which would return the original object super() was
+            # called with, resulting in caching the wrong thing here
Collaborator

This is a bit confusing. What is the proxy object observed here, and how is it related to the autodiff composition?

Collaborator Author

The autodiff relation is incidental - it's only how I happened to notice what seemed like unintentional cache misses.

The wrapper in _comp_cached stores a value for an obj, and is later called with weakref.proxy objects that reference obj. The proxy lookup does not find obj in the cache because it raises TypeError("cannot create weak reference to 'weakcallableproxy' object"), which is caught in

    def wrapper(bctx, obj):
        try:
            obj_cache = bctx._cache.setdefault(obj, dict())
        except TypeError: # 'super()' references can't be cached
            obj_cache = None

and bypasses caching.

The example I found creating these proxies is

    class _node_wrapper():
        def __init__(self, composition, node):
            self._comp = weakref.proxy(composition)

which was added after _comp_cached.
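
The cache miss described above can be reproduced outside PsyNeuLink. The following is a minimal sketch, assuming a plain WeakKeyDictionary as the cache; the `Composition` class and the `unproxy` helper are hypothetical stand-ins and are not PsyNeuLink's own `unproxy_weakproxy`.

```python
import weakref


class Composition:
    """Hypothetical stand-in for the object whose struct types are cached."""


def unproxy(obj):
    # Hypothetical helper (an assumption, not the library's implementation):
    # attribute access on a proxy is forwarded to the referent, so the bound
    # __repr__ method exposes the real object as __self__.
    return obj.__repr__.__self__


cache = weakref.WeakKeyDictionary()
comp = Composition()
proxy = weakref.proxy(comp)

# Store a value under the real object.
cache.setdefault(comp, {})["state"] = "struct type"

# The proxy cannot be used as a key: WeakKeyDictionary creates a weak
# reference to each key, and proxy objects do not support that.
try:
    cache.setdefault(proxy, {})
    proxy_lookup_failed = False
except TypeError:
    proxy_lookup_failed = True

# Unwrapping the proxy first finds the entry stored under the real object.
entry = cache.setdefault(unproxy(proxy), {})
```

In this sketch the proxy lookup raises before the cache is consulted at all, which matches the silent bypass described above once the TypeError is swallowed.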

Collaborator

@jvesely, May 29, 2023

Right. The change to use weakref.proxy (c65e0bd) introduced the bug in PR #2613. I should have checked the total number of generated structures.
I'm still unsure about supporting proxy object caching vs. just reverting c65e0bd.

Does this change fix the high number of generated structures in test_training_then_processing? I'd expect that test (and all autodiff compositions) to run into the 'super()' issue instead.

EDIT: To elaborate, it's interesting that the unproxy_weakproxy function works on super objects as well [0].
We could use it to address the super() codepath as well as the issue introduced in c65e0bd. This would also allow us to remove the entire exception block in _comp_cached.
The fewer isinstance checks and exception blocks on fast paths, the better.

Otherwise, I think it'd be better to just revert c65e0bd.

[0] https://docs.python.org/3/library/functions.html#super
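
To illustrate the point about super objects, here is a small sketch under stated assumptions: the `Base` and `Derived` classes are hypothetical, and the `__repr__.__self__` lookup is only one way such an unwrapping helper might behave, not a statement about how unproxy_weakproxy is implemented.

```python
import weakref


class Base:
    pass


class Derived(Base):
    def via_super(self):
        # Return the super object itself so it can be inspected.
        return super()


d = Derived()
sup = d.via_super()

# A super object cannot be weakly referenced, so it cannot be used as a
# WeakKeyDictionary key.
try:
    weakref.WeakKeyDictionary()[sup] = {}
    super_is_cacheable = True
except TypeError:
    super_is_cacheable = False

# A method looked up through the super object is bound to the original
# instance, so unwrapping yields the object super() was called with.
original = sup.__repr__.__self__
```

Keying the cache on the unwrapped instance would make a `super()` call share a cache entry with the instance itself, which is the concern raised in the code comment of the diff.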

Collaborator Author

> Does this change fix the high number of generated structures in test_training_then_processing? I'd expect that test (and all autodiff compositions) to run into the 'super()' issue instead.

Could you let me know how to check this?
I'm only aware of repeated calls to

    def _get_state_struct_type(self, ctx):
        comp_state_type_list = ctx.get_state_struct_type(super())
        pytorch_representation = self._build_pytorch_representation()
        optimizer_state_type = pytorch_representation._get_compiled_optimizer()._get_optimizer_struct_type(ctx)
        return pnlvm.ir.LiteralStructType((
            *comp_state_type_list,
            optimizer_state_type))

I missed at first that bctx._cache is a WeakKeyDictionary - in that case, it doesn't seem like a problem to just call unproxy_weakproxy each time to let the super objects be stored as well, but I'm not too sure about the intent or benefits of this cache so I'll defer to you.
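
As a sketch of why a WeakKeyDictionary makes storing the unwrapped object harmless for object lifetime: entries disappear once the key is no longer referenced elsewhere. The `Node` class below is a hypothetical stand-in, and the explicit `gc.collect()` is only there so the example does not rely on immediate reference-count collection.

```python
import gc
import weakref


class Node:
    """Hypothetical stand-in for a cached object."""


cache = weakref.WeakKeyDictionary()
node = Node()
cache[node] = {"state": "struct type"}
entries_while_alive = len(cache)

# Dropping the last strong reference lets the entry be discarded; the cache
# itself does not keep the key alive.
del node
gc.collect()
entries_after_release = len(cache)
```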

Collaborator

That's a good question. The comment above mentioned "AutodiffComposition._get_state_struct_type gets called many times", so I thought you had some monitoring set up.
Either way, it prodded me to fix the stats collection for code generation, which is now fixed and extended in #2687.
You should be able to get some numbers by enabling printouts via PNL_LLVM_DEBUG=stat. Running the tests might need -n0, otherwise pytest might hide the output.

The _comp_cached wrapper is a generalized caching decorator for binary structure types used by compiled functions. Any call to get_*_struct_type can be cached. There are many repeated calls to get the same structure because the structure construction is often recursive.

Specifically, the "node wrapper" is a pseudo-object that represents a node and all of its afferent projections. It gets compiled into a single function and reuses the same data types as the composition's execute and run. Thus, generating node wrapper IR code calls the composition's get_*_struct_type.
There are a few places that call get_.*_struct_type(super()), so I'm not sure what caching will do with those. It can "just work", but it might need a closer look.
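
The caching pattern described in this comment can be sketched roughly as follows. This is not the PsyNeuLink implementation: `comp_cached_sketch`, `BuilderContextSketch`, `NodeSketch`, and the call counter are hypothetical names used only to show a per-object cache keyed weakly on the object, with a fallback for keys that cannot be weakly referenced.

```python
import functools
import weakref


def comp_cached_sketch(func):
    """Cache func(bctx, obj) per obj, falling back to no caching."""
    @functools.wraps(func)
    def wrapper(bctx, obj):
        try:
            obj_cache = bctx._cache.setdefault(obj, {})
        except TypeError:
            # Keys that cannot be weakly referenced bypass the cache.
            return func(bctx, obj)
        if func not in obj_cache:
            obj_cache[func] = func(bctx, obj)
        return obj_cache[func]
    return wrapper


class NodeSketch:
    """Hypothetical object whose struct type is requested repeatedly."""


class BuilderContextSketch:
    def __init__(self):
        self._cache = weakref.WeakKeyDictionary()
        self.calls = 0  # counts how often the struct type is actually built

    @comp_cached_sketch
    def get_state_struct_type(self, obj):
        self.calls += 1
        return ("state_struct", type(obj).__name__)


ctx = BuilderContextSketch()
node = NodeSketch()
first = ctx.get_state_struct_type(node)
second = ctx.get_state_struct_type(node)
```

With this shape, a recursive builder that asks for the same structure many times pays the construction cost once per object, as long as the key it passes can actually be found in the cache.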

+            obj = unproxy_weakproxy(obj)
+
         try:
             obj_cache = bctx._cache.setdefault(obj, dict())
         except TypeError: # 'super()' references can't be cached