Skip to content

Commit

Permalink
Add CelStandardFunctions to allow environment subsetting for the runtime
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677902385
  • Loading branch information
l46kok authored and copybara-github committed Nov 20, 2024
1 parent c0dcb67 commit f6592bb
Show file tree
Hide file tree
Showing 9 changed files with 2,503 additions and 95 deletions.
5 changes: 5 additions & 0 deletions common/src/main/java/dev/cel/common/CelException.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ public CelException(String message, Throwable cause) {
super(message, cause);
}

public CelException(String message, CelErrorCode errorCode) {
super(message);
this.errorCode = errorCode;
}

public CelException(String message, Throwable cause, CelErrorCode errorCode) {
super(message, cause);
this.errorCode = errorCode;
Expand Down
8 changes: 6 additions & 2 deletions runtime/src/main/java/dev/cel/runtime/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ java_library(
tags = [
],
deps = [
":runtime_helper",
"//:auto_value",
"//common",
"//common:error_codes",
Expand All @@ -61,13 +60,13 @@ java_library(
"//common/internal:safe_string_formatter",
"//common/types",
"//common/types:type_providers",
"//runtime:runtime_helper",
"@cel_spec//proto/cel/expr:expr_java_proto",
"@maven//:com_google_code_findbugs_annotations",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_google_protobuf_protobuf_java_util",
"@maven//:com_google_re2j_re2j",
"@maven//:org_jspecify_jspecify",
],
)
Expand Down Expand Up @@ -145,6 +144,7 @@ RUNTIME_SOURCES = [
"CelRuntimeFactory.java",
"CelRuntimeLegacyImpl.java",
"CelRuntimeLibrary.java",
"CelStandardFunctions.java",
"CelVariableResolver.java",
"HierarchicalVariableResolver.java",
"UnknownContext.java",
Expand All @@ -157,6 +157,7 @@ java_library(
],
deps = [
":evaluation_listener",
":runtime_helper",
":runtime_type_provider_legacy",
":unknown_attributes",
"//:auto_value",
Expand All @@ -165,6 +166,7 @@ java_library(
"//common:options",
"//common/annotations",
"//common/internal:cel_descriptor_pools",
"//common/internal:comparison_functions",
"//common/internal:default_message_factory",
"//common/internal:dynamic_proto",
"//common/internal:proto_message_factory",
Expand All @@ -176,6 +178,8 @@ java_library(
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_google_protobuf_protobuf_java_util",
"@maven//:com_google_re2j_re2j",
"@maven//:org_jspecify_jspecify",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ public CelEvaluationException(String message, Throwable cause, CelErrorCode erro
super(message, cause, errorCode);
}

public CelEvaluationException(String message, CelErrorCode errorCode) {
super(message, errorCode);
}

CelEvaluationException(InterpreterException cause) {
super(cause.getMessage(), cause.getCause());
this(cause, cause.getErrorCode());
}

CelEvaluationException(InterpreterException cause, CelErrorCode errorCode) {
Expand Down
8 changes: 8 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ public interface CelRuntimeBuilder {
@CanIgnoreReturnValue
CelRuntimeBuilder setStandardEnvironmentEnabled(boolean value);

/**
* Override the standard functions for the runtime. This can be used to subset the standard
* environment to only expose the desired function overloads to the runtime. {@link
* #setStandardEnvironmentEnabled(boolean)} must be set to false for this to take effect.
*/
@CanIgnoreReturnValue
CelRuntimeBuilder setStandardFunctions(CelStandardFunctions standardFunctions);

/** Adds one or more libraries for runtime. */
@CanIgnoreReturnValue
CelRuntimeBuilder addLibraries(CelRuntimeLibrary... libraries);
Expand Down
119 changes: 94 additions & 25 deletions runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import dev.cel.common.types.CelTypes;
import dev.cel.common.values.CelValueProvider;
import dev.cel.common.values.ProtoMessageValueProvider;
import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Comparison;
import dev.cel.runtime.CelStandardFunctions.StandardFunction.Overload.Conversions;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -87,7 +89,7 @@ public static CelRuntimeBuilder newBuilder() {
public static final class Builder implements CelRuntimeBuilder {

private final ImmutableSet.Builder<FileDescriptor> fileTypes;
private final HashMap<String, CelFunctionBinding> functionBindings;
private final HashMap<String, CelFunctionBinding> customFunctionBindings;
private final ImmutableSet.Builder<CelRuntimeLibrary> celRuntimeLibraries;

@SuppressWarnings("unused")
Expand All @@ -97,6 +99,7 @@ public static final class Builder implements CelRuntimeBuilder {
private Function<String, Message.Builder> customTypeFactory;
private ExtensionRegistry extensionRegistry;
private CelValueProvider celValueProvider;
private CelStandardFunctions overriddenStandardFunctions;

@Override
public CelRuntimeBuilder setOptions(CelOptions options) {
Expand All @@ -111,7 +114,7 @@ public CelRuntimeBuilder addFunctionBindings(CelFunctionBinding... bindings) {

@Override
public CelRuntimeBuilder addFunctionBindings(Iterable<CelFunctionBinding> bindings) {
bindings.forEach(o -> functionBindings.putIfAbsent(o.getOverloadId(), o));
bindings.forEach(o -> customFunctionBindings.putIfAbsent(o.getOverloadId(), o));
return this;
}

Expand Down Expand Up @@ -160,6 +163,12 @@ public CelRuntimeBuilder setStandardEnvironmentEnabled(boolean value) {
return this;
}

@Override
public CelRuntimeBuilder setStandardFunctions(CelStandardFunctions standardFunctions) {
this.overriddenStandardFunctions = standardFunctions;
return this;
}

@Override
public CelRuntimeBuilder addLibraries(CelRuntimeLibrary... libraries) {
checkNotNull(libraries);
Expand All @@ -184,7 +193,7 @@ public CelRuntimeBuilder setExtensionRegistry(ExtensionRegistry extensionRegistr
// and shouldn't be exposed to the public.
@VisibleForTesting
Map<String, CelFunctionBinding> getFunctionBindings() {
return this.functionBindings;
return this.customFunctionBindings;
}

@VisibleForTesting
Expand All @@ -200,6 +209,11 @@ ImmutableSet.Builder<FileDescriptor> getFileTypes() {
/** Build a new {@code CelRuntimeLegacyImpl} instance from the builder config. */
@Override
public CelRuntimeLegacyImpl build() {
if (standardEnvironmentEnabled && overriddenStandardFunctions != null) {
throw new IllegalArgumentException(
"setStandardEnvironmentEnabled must be set to false to override standard function"
+ " bindings.");
}
// Add libraries, such as extensions
celRuntimeLibraries.build().forEach(celLibrary -> celLibrary.setRuntimeOptions(this));

Expand Down Expand Up @@ -227,26 +241,33 @@ public CelRuntimeLegacyImpl build() {

DynamicProto dynamicProto = DynamicProto.create(runtimeTypeFactory);

DefaultDispatcher dispatcher =
DefaultDispatcher.create(options, dynamicProto, standardEnvironmentEnabled);

ImmutableMap<String, CelFunctionBinding> functionBindingMap =
ImmutableMap.copyOf(functionBindings);
functionBindingMap.forEach(
(String overloadId, CelFunctionBinding func) ->
dispatcher.add(
overloadId,
func.getArgTypes(),
(args) -> {
try {
return func.getDefinition().apply(args);
} catch (CelEvaluationException e) {
throw new InterpreterException.Builder(e.getMessage())
.setCause(e)
.setErrorCode(e.getErrorCode())
.build();
}
}));
ImmutableMap.Builder<String, CelFunctionBinding> functionBindingsBuilder =
ImmutableMap.builder();
for (CelFunctionBinding standardFunctionBinding : newStandardFunctionBindings(dynamicProto)) {
functionBindingsBuilder.put(
standardFunctionBinding.getOverloadId(), standardFunctionBinding);
}

functionBindingsBuilder.putAll(customFunctionBindings);

DefaultDispatcher dispatcher = DefaultDispatcher.create();
functionBindingsBuilder
.buildOrThrow()
.forEach(
(String overloadId, CelFunctionBinding func) ->
dispatcher.add(
overloadId,
func.getArgTypes(),
(args) -> {
try {
return func.getDefinition().apply(args);
} catch (CelEvaluationException e) {
throw new InterpreterException.Builder(e.getMessage())
.setCause(e)
.setErrorCode(e.getErrorCode())
.build();
}
}));

RuntimeTypeProvider runtimeTypeProvider;

Expand All @@ -271,6 +292,54 @@ public CelRuntimeLegacyImpl build() {
this);
}

private ImmutableSet<CelFunctionBinding> newStandardFunctionBindings(
DynamicProto dynamicProto) {
CelStandardFunctions celStandardFunctions;
if (standardEnvironmentEnabled) {
celStandardFunctions =
CelStandardFunctions.newBuilder()
.filterFunctions(
(standardFunction, standardOverload) -> {
switch (standardFunction) {
case INT:
if (standardOverload.equals(Conversions.INT64_TO_INT64)) {
// Note that we require UnsignedLong flag here to avoid ambiguous
// overloads against "uint64_to_int64", because they both use the same
// Java Long class. We skip adding this identity function if the flag is
// disabled.
return options.enableUnsignedLongs();
}
break;
case TIMESTAMP:
// TODO: Remove this flag guard once the feature has been
// auto-enabled.
if (standardOverload.equals(Conversions.INT64_TO_TIMESTAMP)) {
return options.enableTimestampEpoch();
}
break;
default:
if (standardOverload instanceof Comparison
&& !options.enableHeterogeneousNumericComparisons()) {
Comparison comparison = (Comparison) standardOverload;
if (comparison.isHeterogeneousComparison()) {
return false;
}
}
break;
}

return true;
})
.build();
} else if (overriddenStandardFunctions != null) {
celStandardFunctions = overriddenStandardFunctions;
} else {
return ImmutableSet.of();
}

return celStandardFunctions.newFunctionBindings(dynamicProto, options);
}

private static CelDescriptorPool newDescriptorPool(
CelDescriptors celDescriptors,
ExtensionRegistry extensionRegistry) {
Expand All @@ -294,7 +363,7 @@ private static ProtoMessageFactory maybeCombineMessageFactory(
private Builder() {
this.options = CelOptions.newBuilder().build();
this.fileTypes = ImmutableSet.builder();
this.functionBindings = new HashMap<>();
this.customFunctionBindings = new HashMap<>();
this.celRuntimeLibraries = ImmutableSet.builder();
this.extensionRegistry = ExtensionRegistry.getEmptyRegistry();
this.customTypeFactory = null;
Expand All @@ -311,7 +380,7 @@ private Builder(Builder builder) {
// The following needs to be deep copied as they are collection builders
this.fileTypes = deepCopy(builder.fileTypes);
this.celRuntimeLibraries = deepCopy(builder.celRuntimeLibraries);
this.functionBindings = new HashMap<>(builder.functionBindings);
this.customFunctionBindings = new HashMap<>(builder.customFunctionBindings);
}

private static <T> ImmutableSet.Builder<T> deepCopy(ImmutableSet.Builder<T> builderToCopy) {
Expand Down
Loading

0 comments on commit f6592bb

Please sign in to comment.