Fix comp graph pickling
evhub committed Oct 19, 2024
1 parent 7187d5c commit c439630
Showing 5 changed files with 106 additions and 68 deletions.
93 changes: 43 additions & 50 deletions coconut/compiler/compiler.py
@@ -39,11 +39,6 @@
from threading import Lock
from copy import copy

if sys.version_info >= (3,):
import pickle
else:
import cPickle as pickle

from coconut._pyparsing import (
USE_COMPUTATION_GRAPH,
USE_CACHE,
@@ -109,8 +104,10 @@
incremental_mode_cache_size,
incremental_cache_limit,
use_line_by_line_parser,
coconut_cache_dir,
)
from coconut.util import (
pickle,
pickleable_obj,
checksum,
clip,
@@ -126,6 +123,7 @@
create_method,
univ_open,
staledict,
ensure_dir,
)
from coconut.exceptions import (
CoconutException,
Expand Down Expand Up @@ -161,6 +159,7 @@
ComputationNode,
StartOfStrGrammar,
MatchAny,
CombineToNode,
sys_target,
getline,
addskip,
@@ -210,7 +209,8 @@
get_cache_items_for,
clear_packrat_cache,
add_packrat_cache_items,
get_cache_path,
parse_elem_to_identifier,
identifier_to_parse_elem,
_lookup_loc,
_value_exc_loc_or_ret,
)
@@ -447,10 +447,7 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle
# are the only ones that parseIncremental will reuse
if 0 < loc < len(original) - 1:
elem = lookup[0]
identifier = elem.parse_element_index
internal_assert(lambda: elem == all_parse_elements[identifier](), "failed to look up parse element by identifier", lambda: (elem, all_parse_elements[identifier]()))
if validation_dict is not None:
validation_dict[identifier] = elem.__class__.__name__
identifier = parse_elem_to_identifier(elem, validation_dict)
pickleable_lookup = (identifier,) + lookup[1:]
internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "cache must be dehybridized before pickling", value[_value_exc_loc_or_ret])
pickleable_cache_items.append((pickleable_lookup, value))
@@ -460,21 +457,15 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle
for wkref in MatchAny.all_match_anys:
match_any = wkref()
if match_any is not None and match_any.adaptive_usage is not None:
identifier = match_any.parse_element_index
internal_assert(lambda: match_any == all_parse_elements[identifier](), "failed to look up match_any by identifier", lambda: (match_any, all_parse_elements[identifier]()))
if validation_dict is not None:
validation_dict[identifier] = match_any.__class__.__name__
identifier = parse_elem_to_identifier(match_any, validation_dict)
match_any.expr_order.sort(key=lambda i: (-match_any.adaptive_usage[i], i))
all_adaptive_items.append((identifier, (match_any.adaptive_usage, match_any.expr_order)))
logger.log("Caching adaptive item:", match_any, (match_any.adaptive_usage, match_any.expr_order))

# computation graph cache
computation_graph_cache_items = []
for (call_site_name, grammar_elem), cache in Compiler.computation_graph_caches.items():
identifier = grammar_elem.parse_element_index
internal_assert(lambda: grammar_elem == all_parse_elements[identifier](), "failed to look up grammar by identifier", lambda: (grammar_elem, all_parse_elements[identifier]()))
if validation_dict is not None:
validation_dict[identifier] = grammar_elem.__class__.__name__
identifier = parse_elem_to_identifier(grammar_elem, validation_dict)
computation_graph_cache_items.append(((call_site_name, identifier), cache))

logger.log("Saving {num_inc} incremental, {num_adapt} adaptive, and {num_comp_graph} computation graph cache items to {cache_path!r}.".format(
@@ -492,8 +483,9 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle
"computation_graph_cache_items": computation_graph_cache_items,
}
try:
with univ_open(cache_path, "wb") as pickle_file:
pickle.dump(pickle_info_obj, pickle_file, protocol=protocol)
with CombineToNode.enable_pickling(validation_dict):
with univ_open(cache_path, "wb") as pickle_file:
pickle.dump(pickle_info_obj, pickle_file, protocol=protocol)
except Exception:
logger.log_exc()
return False
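
The dump above is wrapped defensively so that any pickling failure degrades to a missing cache rather than a failed compile. A minimal sketch of that pattern (try_save_cache and the bare open call are illustrative stand-ins for the univ_open/logger machinery in the diff):

import pickle

def try_save_cache(obj, path, log=print):
    """Best-effort cache write: a failed dump is logged and swallowed,
    so an unpicklable or corrupt cache can never break compilation."""
    try:
        with open(path, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        return True
    except Exception as err:
        log("cache write failed: {!r}".format(err))
        return False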
@@ -531,15 +523,25 @@ def unpickle_cache(cache_path):
all_adaptive_items = pickle_info_obj["all_adaptive_items"]
computation_graph_cache_items = pickle_info_obj["computation_graph_cache_items"]

# incremental cache
new_cache_items = []
for pickleable_lookup, value in pickleable_cache_items:
maybe_elem = identifier_to_parse_elem(pickleable_lookup[0], validation_dict)
if maybe_elem is not None:
internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "attempting to unpickle hybrid cache item", value[_value_exc_loc_or_ret])
lookup = (maybe_elem,) + pickleable_lookup[1:]
usefullness = value[-1][0]
internal_assert(usefullness, "loaded useless cache item", (lookup, value))
stale_value = value[:-1] + ([usefullness + 1],)
new_cache_items.append((lookup, stale_value))
add_packrat_cache_items(new_cache_items)

# adaptive cache
for identifier, (adaptive_usage, expr_order) in all_adaptive_items:
if identifier < len(all_parse_elements):
maybe_elem = all_parse_elements[identifier]()
if maybe_elem is not None:
if validation_dict is not None:
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "adaptive cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
maybe_elem.adaptive_usage = adaptive_usage
maybe_elem.expr_order = expr_order
maybe_elem = identifier_to_parse_elem(identifier, validation_dict)
if maybe_elem is not None:
maybe_elem.adaptive_usage = adaptive_usage
maybe_elem.expr_order = expr_order

max_cache_size = min(
incremental_mode_cache_size or float("inf"),
@@ -548,38 +550,29 @@ def unpickle_cache(cache_path):
if max_cache_size != float("inf"):
pickleable_cache_items = pickleable_cache_items[-max_cache_size:]

# incremental cache
new_cache_items = []
for pickleable_lookup, value in pickleable_cache_items:
identifier = pickleable_lookup[0]
if identifier < len(all_parse_elements):
maybe_elem = all_parse_elements[identifier]()
if maybe_elem is not None:
if validation_dict is not None:
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "incremental cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "attempting to unpickle hybrid cache item", value[_value_exc_loc_or_ret])
lookup = (maybe_elem,) + pickleable_lookup[1:]
usefullness = value[-1][0]
internal_assert(usefullness, "loaded useless cache item", (lookup, value))
stale_value = value[:-1] + ([usefullness + 1],)
new_cache_items.append((lookup, stale_value))
add_packrat_cache_items(new_cache_items)

# computation graph cache
for (call_site_name, identifier), cache in computation_graph_cache_items:
if identifier < len(all_parse_elements):
maybe_elem = all_parse_elements[identifier]()
if maybe_elem is not None:
if validation_dict is not None:
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "computation graph cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
Compiler.computation_graph_caches[(call_site_name, maybe_elem)].update(cache)
maybe_elem = identifier_to_parse_elem(identifier, validation_dict)
if maybe_elem is not None:
Compiler.computation_graph_caches[(call_site_name, maybe_elem)].update(cache)

num_inc = len(pickleable_cache_items)
num_adapt = len(all_adaptive_items)
num_comp_graph = sum(len(cache) for _, cache in computation_graph_cache_items) if computation_graph_cache_items else 0
return num_inc, num_adapt, num_comp_graph


def get_cache_path(codepath):
"""Get the cache filename to use for the given codepath."""
code_dir, code_fname = os.path.split(codepath)

cache_dir = os.path.join(code_dir, coconut_cache_dir)
ensure_dir(cache_dir, logger=logger)

pickle_fname = code_fname + ".pkl"
return os.path.join(cache_dir, pickle_fname)


def load_cache_for(inputstring, codepath):
"""Load cache_path (for the given inputstring and filename)."""
if not SUPPORTS_INCREMENTAL:
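
The three duplicated lookup blocks this commit removes are funneled through one pair of helpers: each parse element carries a parse_element_index into a global list of weak references, so caches can be pickled as plain integers and re-resolved on load, with a validation dict of class names to catch a grammar that changed shape between save and load. A self-contained sketch of that round trip (ParseElem, to_identifier, and from_identifier are illustrative names; the real helpers, parse_elem_to_identifier and identifier_to_parse_elem, appear in the util.py diff below):

import weakref

all_parse_elements = []  # registry of weakrefs, appended to as the grammar is built

class ParseElem(object):
    def __init__(self):
        self.parse_element_index = len(all_parse_elements)
        all_parse_elements.append(weakref.ref(self))

def to_identifier(elem, validation_dict=None):
    identifier = elem.parse_element_index
    if validation_dict is not None:
        # remember the class name so unpickling can sanity-check the match
        validation_dict[identifier] = elem.__class__.__name__
    return identifier

def from_identifier(identifier, validation_dict=None):
    if identifier < len(all_parse_elements):
        maybe_elem = all_parse_elements[identifier]()  # dereference the weakref
        if maybe_elem is not None:
            if validation_dict is None or validation_dict.get(identifier) == maybe_elem.__class__.__name__:
                return maybe_elem
    return None  # element was collected or never registered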
69 changes: 54 additions & 15 deletions coconut/compiler/util.py
@@ -28,7 +28,6 @@
from coconut.root import * # NOQA

import sys
import os
import re
import ast
import inspect
@@ -49,7 +48,6 @@
SUPPORTS_INCREMENTAL,
SUPPORTS_ADAPTIVE,
SUPPORTS_PACKRAT_CONTEXT,
replaceWith,
ZeroOrMore,
OneOrMore,
Optional,
@@ -77,13 +75,15 @@

from coconut.integrations import embed
from coconut.util import (
pickle,
override,
get_name,
get_target_info,
memoize,
ensure_dir,
get_clock_time,
literal_lines,
const,
pickleable_obj,
)
from coconut.terminal import (
logger,
@@ -120,14 +120,14 @@
incremental_cache_limit,
incremental_mode_cache_successes,
use_adaptive_any_of,
coconut_cache_dir,
use_fast_pyparsing_reprs,
require_cache_clear_frac,
reverse_any_of,
all_keywords,
always_keep_parse_name_prefix,
keep_if_unchanged_parse_name_prefix,
incremental_use_hybrid,
test_computation_graph_pickling,
)
from coconut.exceptions import (
CoconutException,
@@ -315,7 +315,7 @@ def build_new_toks_for(tokens, new_toklist, unchanged=False):
cached_trim_arity = memoize()(_trim_arity)


class ComputationNode(object):
class ComputationNode(pickleable_obj):
"""A single node in the computation graph."""
__slots__ = ("action", "original", "loc", "tokens", "trim_arity")
pprinting = False
Expand All @@ -339,6 +339,12 @@ def __new__(cls, action, original, loc, tokens, trim_arity=True, ignore_no_token
If ignore_no_tokens, then don't call the action if there are no tokens.
If ignore_one_token, then don't call the action if there is only one token.
If greedy, then never defer the action until later."""
if test_computation_graph_pickling:
with CombineToNode.enable_pickling():
try:
pickle.dumps(action, protocol=pickle.HIGHEST_PROTOCOL)
except Exception:
raise ValueError("unpickleable action in ComputationNode: " + repr(action))
if ignore_no_tokens and len(tokens) == 0 or ignore_one_token and len(tokens) == 1:
# could be a ComputationNode, so we can't have an __init__
return build_new_toks_for(tokens, tokens, unchanged=True)
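
The new test_computation_graph_pickling check exists because pickle serializes functions by reference: a module-level callable round-trips, but a lambda or closure only fails later, when the whole cache is dumped. Eagerly test-pickling each action surfaces the bad callable at node-construction time instead. A small demonstration of the underlying pickle behavior (action and bad_action are made-up names):

import pickle

def action(loc, tokens):  # module-level: pickled by reference, round-trips fine
    return tokens

bad_action = lambda loc, tokens: tokens  # lambdas cannot be pickled by reference

pickle.dumps(action, protocol=pickle.HIGHEST_PROTOCOL)
try:
    pickle.dumps(bad_action, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as err:
    print("caught at construction time instead of dump time:", err)

This is likely also why fixto later in this diff swaps pyparsing's replaceWith (which builds a closure) for coconut.util.const, and why invalid_syntax_handle is a named module-level function.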
@@ -452,9 +458,10 @@ def evaluate(self):
raise self.exception_maker()


class CombineToNode(Combine):
class CombineToNode(Combine, pickleable_obj):
"""Modified Combine to work with the computation graph."""
__slots__ = ()
validation_dict = None

def _combine(self, original, loc, tokens):
"""Implement the parse action for Combine."""
@@ -468,6 +475,26 @@ def postParse(self, original, loc, tokens):
"""Create a ComputationNode for Combine."""
return ComputationNode(self._combine, original, loc, tokens, ignore_no_tokens=True, ignore_one_token=True, trim_arity=False)

@classmethod
def reconstitute(cls, identifier):
return identifier_to_parse_elem(identifier, cls.validation_dict)

def __reduce__(self):
if self.validation_dict is None:
return super(CombineToNode, self).__reduce__()
else:
return (self.reconstitute, (parse_elem_to_identifier(self, self.validation_dict),))

@classmethod
@contextmanager
def enable_pickling(cls, validation_dict={}):
"""Context manager to enable pickling for CombineToNode."""
old_validation_dict, CombineToNode.validation_dict = CombineToNode.validation_dict, validation_dict
try:
yield
finally:
CombineToNode.validation_dict = old_validation_dict


if USE_COMPUTATION_GRAPH:
combine = CombineToNode
@@ -1136,15 +1163,24 @@ def disable_incremental_parsing():
force_reset_packrat_cache()


def get_cache_path(codepath):
"""Get the cache filename to use for the given codepath."""
code_dir, code_fname = os.path.split(codepath)
def parse_elem_to_identifier(elem, validation_dict=None):
"""Get the identifier for the given parse element."""
identifier = elem.parse_element_index
internal_assert(lambda: elem == all_parse_elements[identifier](), "failed to look up parse element by identifier", lambda: (elem, all_parse_elements[identifier]()))
if validation_dict is not None:
validation_dict[identifier] = elem.__class__.__name__
return identifier

cache_dir = os.path.join(code_dir, coconut_cache_dir)
ensure_dir(cache_dir, logger=logger)

pickle_fname = code_fname + ".pkl"
return os.path.join(cache_dir, pickle_fname)
def identifier_to_parse_elem(identifier, validation_dict=None):
"""Get the parse element for the given identifier."""
if identifier < len(all_parse_elements):
maybe_elem = all_parse_elements[identifier]()
if maybe_elem is not None:
if validation_dict is not None:
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "parse element pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
return maybe_elem
return None


# -----------------------------------------------------------------------------------------------------------------------
@@ -1350,11 +1386,14 @@ def add_labels(tokens):
return (item, tokens._ParseResults__tokdict.keys())


def invalid_syntax_handle(msg, loc, tokens):
def invalid_syntax_handle(msg, original, loc, tokens):
"""Pickleable handler for invalid_syntax."""
raise CoconutDeferredSyntaxError(msg, loc)


invalid_syntax_handle.trim_arity = False # fixes pypy issue


def invalid_syntax(item, msg, **kwargs):
"""Mark a grammar item as an invalid item that raises a syntax err with msg."""
if isinstance(item, str):
@@ -1405,7 +1444,7 @@ def regex_item(regex, options=None):

def fixto(item, output):
"""Force an item to result in a specific output."""
return attach(item, replaceWith(output), ignore_arguments=True)
return attach(item, const([output]), ignore_arguments=True)


def addspace(item):
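
CombineToNode's new __reduce__ is the standard pickle hook: while enable_pickling is active, an instance serializes as (reconstitute, (identifier,)) and resolves back to the live grammar object on load; otherwise pickling falls through to the default behavior. A toy version of the same toggle, independent of pyparsing (Registered and its registry are stand-ins for all_parse_elements):

import pickle
from contextlib import contextmanager

class Registered(object):
    registry = []  # stands in for all_parse_elements
    pickling_enabled = False

    def __init__(self):
        self.index = len(Registered.registry)
        Registered.registry.append(self)

    @classmethod
    def reconstitute(cls, index):
        return cls.registry[index]  # resolve back to the live object

    def __reduce__(self):
        if Registered.pickling_enabled:
            # serialize as a bare identifier, not as object state
            return (Registered.reconstitute, (self.index,))
        return super(Registered, self).__reduce__()

    @classmethod
    @contextmanager
    def enable_pickling(cls):
        old, cls.pickling_enabled = cls.pickling_enabled, True
        try:
            yield
        finally:
            cls.pickling_enabled = old

elem = Registered()
with Registered.enable_pickling():
    data = pickle.dumps(elem)
assert pickle.loads(data) is elem  # round-trips to the same live object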
3 changes: 2 additions & 1 deletion coconut/constants.py
@@ -172,8 +172,9 @@ def get_path_env_var(env_var, default):
# COMPILER CONSTANTS:
# -----------------------------------------------------------------------------------------------------------------------

# set this to True only ever temporarily for ease of debugging
# set these to True only ever temporarily for ease of debugging
embed_on_internal_exc = get_bool_env_var("COCONUT_EMBED_ON_INTERNAL_EXC", False)
test_computation_graph_pickling = False

# should be the minimal ref count observed by maybe_copy_elem
temp_grammar_item_ref_count = 4 if PY311 else 5
3 changes: 2 additions & 1 deletion coconut/tests/constants_test.py
@@ -80,8 +80,9 @@ class TestConstants(unittest.TestCase):

def test_defaults(self):
assert constants.use_fast_pyparsing_reprs
assert not constants.embed_on_internal_exc
assert constants.num_assemble_logical_lines_tries >= 1
assert not constants.embed_on_internal_exc
assert not constants.test_computation_graph_pickling

def test_fixpath(self):
assert os.path.basename(fixpath("CamelCase.py")) == "CamelCase.py"
6 changes: 5 additions & 1 deletion coconut/util.py
@@ -39,6 +39,10 @@
from backports.functools_lru_cache import lru_cache
except ImportError:
lru_cache = None
if sys.version_info >= (3,):
import pickle # NOQA
else:
import cPickle as pickle # NOQA

from coconut.root import _get_target_info
from coconut.constants import (
@@ -286,7 +290,7 @@ def add(self, item):
self[item] = True


class staledict(dict, object):
class staledict(dict, pickleable_obj):
"""A dictionary that keeps track of staleness.
Initial elements are always marked as stale and pickling always prunes stale elements."""
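
staledict now inherits from pickleable_obj so instances can travel inside the cache pickle. Per its docstring, entries loaded from a pickle start out stale, and pickling prunes stale entries, so only items actually touched since the last load survive the next save. A toy model of that lifecycle (not the actual implementation):

import pickle

class ToyStaleDict(dict):
    """Loaded entries start stale; writing a key freshens it;
    pickling keeps only fresh entries."""

    def __init__(self, *args, **kwargs):
        super(ToyStaleDict, self).__init__(*args, **kwargs)
        self.fresh = set()  # keys written since construction/unpickling

    def __setitem__(self, key, value):
        super(ToyStaleDict, self).__setitem__(key, value)
        self.fresh.add(key)

    def __reduce__(self):
        # prune stale entries: only freshly written keys are serialized
        kept = dict((k, v) for k, v in self.items() if k in self.fresh)
        return (ToyStaleDict, (kept,))

d = ToyStaleDict({"loaded": 1})  # dict init bypasses __setitem__, so "loaded" is stale
d["touched"] = 2                 # freshened by the write
assert dict(pickle.loads(pickle.dumps(d))) == {"touched": 2}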
