Skip to content

Commit

Permalink
Surface type-mismatch errors in a readable fashion during rule compos…
Browse files Browse the repository at this point in the history
…ition

PiperOrigin-RevId: 652621800
  • Loading branch information
l46kok authored and copybara-github committed Jul 15, 2024
1 parent d20d377 commit 9f78ec3
Show file tree
Hide file tree
Showing 16 changed files with 316 additions and 61 deletions.
3 changes: 3 additions & 0 deletions policy/src/main/java/dev/cel/policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ java_library(
name = "compiled_rule",
srcs = ["CelCompiledRule.java"],
deps = [
":value_string",
"//:auto_value",
"//bundle:cel",
"//common",
Expand Down Expand Up @@ -301,12 +302,14 @@ java_library(
"//:auto_value",
"//bundle:cel",
"//common",
"//common:compiler_common",
"//common:mutable_ast",
"//common/ast",
"//extensions:optional_library",
"//optimizer:ast_optimizer",
"//optimizer:mutable_ast",
"//parser:operator",
"//policy:value_string",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
Expand Down
28 changes: 24 additions & 4 deletions policy/src/main/java/dev/cel/policy/CelCompiledRule.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
import dev.cel.bundle.Cel;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelVarDecl;
import java.util.Optional;

/**
* Abstract representation of a compiled rule. This contains set of compiled variables and match
* statements which defines an expression graph for a policy.
*/
@AutoValue
public abstract class CelCompiledRule {
public abstract Optional<ValueString> id();

public abstract ImmutableList<CelCompiledVariable> variables();

public abstract ImmutableList<CelCompiledMatch> matches();
Expand Down Expand Up @@ -63,14 +66,15 @@ public abstract static class CelCompiledMatch {
/** Encapsulates the result of this match when condition is met. (either an output or a rule) */
@AutoOneOf(CelCompiledMatch.Result.Kind.class)
public abstract static class Result {
public abstract CelAbstractSyntaxTree output();
public abstract OutputValue output();

public abstract CelCompiledRule rule();

public abstract Kind kind();

static Result ofOutput(CelAbstractSyntaxTree value) {
return AutoOneOf_CelCompiledRule_CelCompiledMatch_Result.output(value);
static Result ofOutput(long id, CelAbstractSyntaxTree ast) {
return AutoOneOf_CelCompiledRule_CelCompiledMatch_Result.output(
OutputValue.create(id, ast));
}

static Result ofRule(CelCompiledRule value) {
Expand All @@ -84,16 +88,32 @@ public enum Kind {
}
}

/**
* Encapsulates the output value of the match with its original ID that was used to compile
* with.
*/
@AutoValue
public abstract static class OutputValue {
public abstract long id();

public abstract CelAbstractSyntaxTree ast();

public static OutputValue create(long id, CelAbstractSyntaxTree ast) {
return new AutoValue_CelCompiledRule_CelCompiledMatch_OutputValue(id, ast);
}
}

static CelCompiledMatch create(
CelAbstractSyntaxTree condition, CelCompiledMatch.Result result) {
return new AutoValue_CelCompiledRule_CelCompiledMatch(condition, result);
}
}

static CelCompiledRule create(
Optional<ValueString> id,
ImmutableList<CelCompiledVariable> variables,
ImmutableList<CelCompiledMatch> matches,
Cel cel) {
return new AutoValue_CelCompiledRule(variables, matches, cel);
return new AutoValue_CelCompiledRule(id, variables, matches, cel);
}
}
5 changes: 3 additions & 2 deletions policy/src/main/java/dev/cel/policy/CelPolicyCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public interface CelPolicyCompiler {
* CEL environment.
*/
default CelAbstractSyntaxTree compile(CelPolicy policy) throws CelPolicyValidationException {
return compose(compileRule(policy));
return compose(policy, compileRule(policy));
}

/**
Expand All @@ -40,5 +40,6 @@ default CelAbstractSyntaxTree compile(CelPolicy policy) throws CelPolicyValidati
* Composes {@link CelCompiledRule}, representing an expression graph, into a single expression
* value.
*/
CelAbstractSyntaxTree compose(CelCompiledRule compiledRule) throws CelPolicyValidationException;
CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule)
throws CelPolicyValidationException;
}
50 changes: 43 additions & 7 deletions policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package dev.cel.policy;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;

import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand All @@ -37,7 +38,9 @@
import dev.cel.policy.CelCompiledRule.CelCompiledVariable;
import dev.cel.policy.CelPolicy.Match;
import dev.cel.policy.CelPolicy.Variable;
import dev.cel.policy.RuleComposer.RuleCompositionException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

Expand All @@ -61,7 +64,7 @@ public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationE
}

@Override
public CelAbstractSyntaxTree compose(CelCompiledRule compiledRule)
public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule)
throws CelPolicyValidationException {
CelOptimizer optimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(compiledRule.cel())
Expand All @@ -77,8 +80,28 @@ public CelAbstractSyntaxTree compose(CelCompiledRule compiledRule)
ast = cel.compile("true").getAst();
ast = optimizer.optimize(ast);
} catch (CelValidationException | CelOptimizationException e) {
// TODO: Surface these errors better
throw new CelPolicyValidationException("Failed composing the rules", e);
if (e.getCause() instanceof RuleCompositionException) {
RuleCompositionException re = (RuleCompositionException) e.getCause();
CompilerContext compilerContext = new CompilerContext(policy.policySource());
// The exact CEL error message produced from composition failure isn't too useful for users.
// Ex: ERROR: :1:1: found no matching overload for '_?_:_' applied to '(bool, map(int, int),
// bool)' (candidates: (bool, %A0, %A0))
// Transform the error messages in a user-friendly way while retaining the original
// CelValidationException as its originating cause.

ImmutableList<CelIssue> transformedIssues =
re.compileException.getErrors().stream()
.map(x -> CelIssue.formatError(x.getSourceLocation(), re.failureReason))
.collect(toImmutableList());
for (long id : re.errorIds) {
compilerContext.addIssue(id, transformedIssues);
}

throw new CelPolicyValidationException(compilerContext.getIssueString(), re.getCause());
}

// Something has gone seriously wrong.
throw new CelPolicyValidationException("Unexpected error while composing rules.", e);
}

return ast;
Expand Down Expand Up @@ -111,6 +134,11 @@ private CelCompiledRule compileRuleImpl(
CelAbstractSyntaxTree conditionAst;
try {
conditionAst = ruleCel.compile(match.condition().value()).getAst();
if (!conditionAst.getResultType().equals(SimpleType.BOOL)) {
compilerContext.addIssue(
match.condition().id(),
CelIssue.formatError(1, 0, "condition must produce a boolean output."));
}
} catch (CelValidationException e) {
compilerContext.addIssue(match.condition().id(), e.getErrors());
continue;
Expand All @@ -120,14 +148,15 @@ private CelCompiledRule compileRuleImpl(
switch (match.result().kind()) {
case OUTPUT:
CelAbstractSyntaxTree outputAst;
ValueString output = match.result().output();
try {
outputAst = ruleCel.compile(match.result().output().value()).getAst();
outputAst = ruleCel.compile(output.value()).getAst();
} catch (CelValidationException e) {
compilerContext.addIssue(match.result().output().id(), e.getErrors());
compilerContext.addIssue(output.id(), e.getErrors());
continue;
}

matchResult = Result.ofOutput(outputAst);
matchResult = Result.ofOutput(output.id(), outputAst);
break;
case RULE:
CelCompiledRule nestedRule =
Expand All @@ -141,7 +170,7 @@ private CelCompiledRule compileRuleImpl(
matchBuilder.add(CelCompiledMatch.create(conditionAst, matchResult));
}

return CelCompiledRule.create(variableBuilder.build(), matchBuilder.build(), cel);
return CelCompiledRule.create(rule.id(), variableBuilder.build(), matchBuilder.build(), cel);
}

private static CelAbstractSyntaxTree newErrorAst() {
Expand All @@ -153,6 +182,10 @@ private static final class CompilerContext {
private final ArrayList<CelIssue> issues;
private final CelPolicySource celPolicySource;

private void addIssue(long id, CelIssue... issues) {
addIssue(id, Arrays.asList(issues));
}

private void addIssue(long id, List<CelIssue> issues) {
for (CelIssue issue : issues) {
CelSourceLocation absoluteLocation = computeAbsoluteLocation(id, issue);
Expand All @@ -163,6 +196,9 @@ private void addIssue(long id, List<CelIssue> issues) {
private CelSourceLocation computeAbsoluteLocation(long id, CelIssue issue) {
int policySourceOffset =
Optional.ofNullable(celPolicySource.getPositionsMap().get(id)).orElse(-1);
if (policySourceOffset == -1) {
return CelSourceLocation.NONE;
}
CelSourceLocation policySourceLocation =
celPolicySource.getOffsetLocation(policySourceOffset).orElse(null);
if (policySourceLocation == null) {
Expand Down
75 changes: 64 additions & 11 deletions policy/src/main/java/dev/cel/policy/RuleComposer.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.stream.Collectors.toCollection;

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import dev.cel.bundle.Cel;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelMutableAst;
import dev.cel.common.CelValidationException;
import dev.cel.common.ast.CelConstant.Kind;
import dev.cel.extensions.CelOptionalLibrary.Function;
import dev.cel.optimizer.AstMutator;
import dev.cel.optimizer.CelAstOptimizer;
import dev.cel.parser.Operator;
import dev.cel.policy.CelCompiledRule.CelCompiledMatch;
import dev.cel.policy.CelCompiledRule.CelCompiledMatch.OutputValue;
import dev.cel.policy.CelCompiledRule.CelCompiledVariable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Package-private class for composing various rules into a single expression using optimizer. */
final class RuleComposer implements CelAstOptimizer {
Expand All @@ -39,13 +44,8 @@ final class RuleComposer implements CelAstOptimizer {

@Override
public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel) {
RuleOptimizationResult result = optimizeRule(compiledRule);
return OptimizationResult.create(
result.ast().toParsedAst(),
compiledRule.variables().stream()
.map(CelCompiledVariable::celVarDecl)
.collect(toImmutableList()),
ImmutableList.of());
RuleOptimizationResult result = optimizeRule(cel, compiledRule);
return OptimizationResult.create(result.ast().toParsedAst());
}

@AutoValue
Expand All @@ -59,20 +59,33 @@ static RuleOptimizationResult create(CelMutableAst ast, boolean isOptionalResult
}
}

private RuleOptimizationResult optimizeRule(CelCompiledRule compiledRule) {
private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRule) {
cel =
cel.toCelBuilder()
.addVarDeclarations(
compiledRule.variables().stream()
.map(CelCompiledVariable::celVarDecl)
.collect(toImmutableList()))
.build();

CelMutableAst matchAst = astMutator.newGlobalCall(Function.OPTIONAL_NONE.getFunction());
boolean isOptionalResult = true;
// Keep track of the last output ID that might cause type-check failure while attempting to
// compose the subgraphs.
long lastOutputId = 0;
for (CelCompiledMatch match : Lists.reverse(compiledRule.matches())) {
CelAbstractSyntaxTree conditionAst = match.condition();
boolean isTriviallyTrue =
conditionAst.getExpr().constantOrDefault().getKind().equals(Kind.BOOLEAN_VALUE)
&& conditionAst.getExpr().constant().booleanValue();
switch (match.result().kind()) {
case OUTPUT:
CelMutableAst outAst = CelMutableAst.fromCelAst(match.result().output());
OutputValue matchOutput = match.result().output();
CelMutableAst outAst = CelMutableAst.fromCelAst(matchOutput.ast());
if (isTriviallyTrue) {
matchAst = outAst;
isOptionalResult = false;
lastOutputId = matchOutput.id();
continue;
}
if (isOptionalResult) {
Expand All @@ -85,9 +98,13 @@ private RuleOptimizationResult optimizeRule(CelCompiledRule compiledRule) {
CelMutableAst.fromCelAst(conditionAst),
outAst,
matchAst);
assertComposedAstIsValid(
cel, matchAst, "conflicting output types found.", matchOutput.id(), lastOutputId);
lastOutputId = matchOutput.id();
continue;
case RULE:
RuleOptimizationResult nestedRule = optimizeRule(match.result().rule());
CelCompiledRule matchNestedRule = match.result().rule();
RuleOptimizationResult nestedRule = optimizeRule(cel, matchNestedRule);
CelMutableAst nestedRuleAst = nestedRule.ast();
if (isOptionalResult && !nestedRule.isOptionalResult()) {
nestedRuleAst =
Expand All @@ -101,6 +118,13 @@ private RuleOptimizationResult optimizeRule(CelCompiledRule compiledRule) {
throw new IllegalArgumentException("Subrule early terminates policy");
}
matchAst = astMutator.newMemberCall(nestedRuleAst, Function.OR.getFunction(), matchAst);
assertComposedAstIsValid(
cel,
matchAst,
String.format(
"failed composing the subrule '%s' due to conflicting output types.",
matchNestedRule.id().map(ValueString::value).orElse("")),
lastOutputId);
break;
}
}
Expand All @@ -127,9 +151,38 @@ static RuleComposer newInstance(
return new RuleComposer(compiledRule, variablePrefix, iterationLimit);
}

private void assertComposedAstIsValid(
Cel cel, CelMutableAst composedAst, String failureMessage, Long... ids) {
assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids));
}

private void assertComposedAstIsValid(
Cel cel, CelMutableAst composedAst, String failureMessage, List<Long> ids) {
try {
cel.check(composedAst.toParsedAst()).getAst();
} catch (CelValidationException e) {
ids = ids.stream().filter(id -> id > 0).collect(toCollection(ArrayList::new));
throw new RuleCompositionException(failureMessage, e, ids);
}
}

private RuleComposer(CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) {
this.compiledRule = checkNotNull(compiledRule);
this.variablePrefix = variablePrefix;
this.astMutator = AstMutator.newInstance(iterationLimit);
}

static final class RuleCompositionException extends RuntimeException {
final String failureReason;
final List<Long> errorIds;
final CelValidationException compileException;

private RuleCompositionException(
String failureReason, CelValidationException e, List<Long> errorIds) {
super(e);
this.failureReason = failureReason;
this.errorIds = errorIds;
this.compileException = e;
}
}
}
Loading

0 comments on commit 9f78ec3

Please sign in to comment.