Skip to content

Commit

Permalink
Merge pull request #42428 from chiranSachintha/issue-42356
Browse files Browse the repository at this point in the history
Generate closures for default values in type inclusions
  • Loading branch information
LakshanWeerasinghe authored Apr 5, 2024
2 parents c4cd054 + dfde2e3 commit 7782a44
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.wso2.ballerinalang.compiler.semantics.model.symbols.SymTag;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.Symbols;
import org.wso2.ballerinalang.compiler.semantics.model.types.BInvokableType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BRecordType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.tree.BLangAnnotation;
import org.wso2.ballerinalang.compiler.tree.BLangAnnotationAttachment;
Expand Down Expand Up @@ -203,6 +204,7 @@
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

Expand Down Expand Up @@ -419,9 +421,52 @@ public void visit(BLangRecordTypeNode recordTypeNode) {
rewrite(field, recordTypeNode.typeDefEnv);
}
recordTypeNode.restFieldType = rewrite(recordTypeNode.restFieldType, env);
// In the current implementation, closures generated for default values in inclusions defined in a
// separate module are unidentifiable.
// Due to that, if the inclusions are in different modules, we generate closures again.
// Will be fixed with #41949 issue.
generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule(recordTypeNode);
result = recordTypeNode;
}

private List<String> getFieldNames(List<BLangSimpleVariable> fields) {
List<String> fieldNames = new ArrayList<>();
for (BLangSimpleVariable field : fields) {
fieldNames.add(field.name.getValue());
}
return fieldNames;
}

private void generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule(
BLangRecordTypeNode recordTypeNode) {
if (recordTypeNode.typeRefs.isEmpty()) {
return;
}
List<String> fieldNames = getFieldNames(recordTypeNode.fields);
BTypeSymbol typeSymbol = recordTypeNode.getBType().tsymbol;
String typeName = recordTypeNode.symbol.name.value;
PackageID packageID = typeSymbol.pkgID;
for (BLangType type : recordTypeNode.typeRefs) {
BType bType = type.getBType();
if (packageID.equals(bType.tsymbol.pkgID)) {
continue;
}
BRecordType recordType = (BRecordType) Types.getReferredType(bType);
Map<String, BInvokableSymbol> defaultValuesOfTypeRef =
((BRecordTypeSymbol) recordType.tsymbol).defaultValues;
for (Map.Entry<String, BInvokableSymbol> defaultValue : defaultValuesOfTypeRef.entrySet()) {
String name = defaultValue.getKey();
if (fieldNames.contains(name)) {
continue;
}
BInvokableSymbol symbol = defaultValue.getValue();
BLangInvocation invocation = getInvocation(symbol);
String closureName = RECORD_DELIMITER + typeName + RECORD_DELIMITER + name;
generateClosureForDefaultValues(closureName, name, invocation, symbol.retType, typeSymbol);
}
}
}

@Override
public void visit(BLangTupleTypeNode tupleTypeNode) {
BTypeSymbol typeSymbol = tupleTypeNode.getBType().tsymbol;
Expand Down Expand Up @@ -552,7 +597,7 @@ public void visit(BLangSimpleVariable varNode) {
return;
}

if (varNode.symbol != null && Symbols.isFlagOn(varNode.symbol.flags, Flags.DEFAULTABLE_PARAM)) {
if (Symbols.isFlagOn(varNode.symbol.flags, Flags.DEFAULTABLE_PARAM)) {
String closureName = generateName(varNode.symbol.name.value, env.node);
generateClosureForDefaultValues(closureName, varNode.name.value, varNode);
} else {
Expand All @@ -574,13 +619,18 @@ private BSymbol getOwner(SymbolEnv symbolEnv) {
}

private void generateClosureForDefaultValues(String closureName, String paramName, BLangSimpleVariable varNode) {
generateClosureForDefaultValues(closureName, paramName, varNode.expr, varNode.getBType(),
env.node.getBType().tsymbol);
}

private void generateClosureForDefaultValues(String closureName, String paramName, BLangExpression expr,
BType returnType, BTypeSymbol symbol) {
BSymbol owner = getOwner(env);
BLangFunction function = createFunction(closureName, varNode.pos, owner.pkgID, owner, varNode.getBType());
BLangFunction function = createFunction(closureName, expr.pos, owner.pkgID, owner, returnType);
BLangReturn returnStmt = ASTBuilderUtil.createReturnStmt(function.pos, (BLangBlockFunctionBody) function.body);
returnStmt.expr = types.addConversionExprIfRequired(varNode.expr, function.returnTypeNode.getBType());
returnStmt.expr = types.addConversionExprIfRequired(expr, function.returnTypeNode.getBType());
BLangLambdaFunction lambdaFunction = createLambdaFunction(function);
BInvokableSymbol varSymbol = createSimpleVariable(function, lambdaFunction, false);
BTypeSymbol symbol = env.node.getBType().tsymbol;
if (symbol.getKind() == SymbolKind.INVOKABLE_TYPE) {
BInvokableTypeSymbol invokableTypeSymbol = (BInvokableTypeSymbol) symbol;
updateFunctionParams(function, invokableTypeSymbol.params, paramName);
Expand Down Expand Up @@ -1053,7 +1103,8 @@ public void visit(BLangCompoundAssignment compoundAssignment) {
public void visit(BLangInvocation invocation) {
rewriteInvocationExpr(invocation);
BLangInvokableNode encInvokable = env.enclInvokable;
if (encInvokable == null || !invocation.functionPointerInvocation) {
if (encInvokable == null || !invocation.functionPointerInvocation ||
!env.enclPkg.packageID.equals(invocation.symbol.pkgID)) {
return;
}
updateClosureVariable((BVarSymbol) invocation.symbol, encInvokable, invocation.pos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public class FunctionSignatureInBalaTest {

@BeforeClass
public void setup() {
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_utils");
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_functions");
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project");
compileResult =
BCompileUtil.compile("test-src/bala/test_bala/functions/test_different_function_signatures.bal");
Expand Down Expand Up @@ -365,6 +367,11 @@ public void testInvocationWithArgVarargMix() {
public void testCyclicFuncCallWhenFuncDefinedInModuleWithSameName() {
BRunUtil.invoke(compileResult, "testCyclicFuncCallWhenFuncDefinedInModuleWithSameName");
}

@Test
public void testFuncCallingFuncFromDifferentModuleAsParamDefault() {
BRunUtil.invoke(compileResult, "testFuncCallingFuncFromDifferentModuleAsParamDefault");
}

@Test
public void testNegativeFunctionInvocations() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class ImmutabilityBalaTest {
public void setup() {
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_selectively_immutable");
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_immutable");
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_types");
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_records");
result = BCompileUtil.compile("test-src/bala/test_bala/readonly/test_selectively_immutable_type.bal");
inherentlyImmutableResult = BCompileUtil.compile(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class OpenRecordTypeInclusionTest {

@BeforeClass
public void setup() {
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_types");
BCompileUtil.compileAndCacheBala("test-src/bala/test_projects/test_project_records");
compileResult = BCompileUtil.compile("test-src/record/open_record_type_inclusion.bal");
}
Expand Down Expand Up @@ -192,6 +193,7 @@ public Object[] testFunctions() {
"testCyclicRecord",
"testOutOfOrderFieldOverridingFieldFromTypeInclusion",
"testCreatingRecordWithOverriddenFields",
"testDefaultValuesOfRecordFieldsWithTypeInclusion"
};
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import testorg/foo as foo;
import testorg/functions;

//------------ Testing a function with all types of parameters ---------

Expand Down Expand Up @@ -85,6 +86,12 @@ function testInvokeFuncWithAnyRestParam1() returns any[] {
return foo:functionAnyRestParam(a, j);
}

// ------------------- Test function signature with invocation as default value of parameter

function testFuncCallingFuncFromDifferentModuleAsParamDefault() {
assertValueEquality(101, functions:funcCallingFuncFromDifferentModuleAsParamDefault());
}

// ------------------- Test function signature with union types for default parameter

function testFuncWithUnionTypedDefaultParam() returns json {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

type BooleanArray boolean[];
type StringArray string[];

Expand Down Expand Up @@ -46,7 +45,6 @@ public function functionAnyRestParam(any... z) returns any[] {
return z;
}


// ------------------- Test function signature with union types for default parameter

public function funcWithUnionTypedDefaultParam(string|int? s = "John") returns string|int? {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[package]
org = "testorg"
name = "functions"
version = "0.1.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) 2024 WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
//
// WSO2 LLC. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import testorg/utils;

// ------------------- Test function signature with call of function from a different module as parameter default

public function funcCallingFuncFromDifferentModuleAsParamDefault(int a = utils:foo()) returns int {
return a;
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// specific language governing permissions and limitations
// under the License.

import testOrg/public_and_private_types;

public type BClosedPerson record {|
string name = "anonymous";
int age = 0;
Expand All @@ -39,3 +41,23 @@ public type ClosedVehicleWithNever record {|
int j;
never p?;
|};

public type BClosedStudent record {|
string name = "anonymous";
int age = 20;
|};

public type Info record {|
*public_and_private_types:Person;
|};

public type Info1 record {|
string name = "James";
*public_and_private_types:Person;
|};

public type Location record {
*public_and_private_types:Address;
string street = "abc";
int zipCode = 123;
};
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ import public_and_private_types.parser;
public type ErrorDetail record {|
*parser:ErrorDetail;
|};

public type Person record {|
string name = "John";
int age = 30;
|};

public type Address record {|
string city;
string country;
|};
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ public function assertTrue(any|error actual) {
assertEquality(true, actual);
}

public function foo() returns int {
return 101;
}

public function assertEquality(any|error expected, any|error actual) {
if expected is anydata && actual is anydata && expected == actual {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,41 @@ type Baz record {
int id;
};

type Student record {
*records:BClosedStudent;
};

type Student1 record {
string name = "chiranS";
*records:BClosedStudent;
};

type Foo5 record {
*records:BClosedStudent;
*records:BClosedPerson;
string name = "chiranS";
int age = 25;
};

type Foo6 record {
*records:BClosedStudent;
*records:BClosedPerson;
string name;
int age;
};

type Info record {
*records:Info;
};

type Info1 record {
*records:Info1;
};

type Location record {|
*records:Location;
|};

function testOutOfOrderFieldOverridingFieldFromTypeInclusion() {
Baz bazRecord = {id: 4};
Bar barRecord = {body: bazRecord};
Expand All @@ -228,6 +263,32 @@ function testCyclicRecord() {
assertEquality(34, cc?.auth?.d1?.x);
}

function testDefaultValuesOfRecordFieldsWithTypeInclusion() {
Student student = {};
assertEquality("anonymous", student.name);
assertEquality(20, student.age);
Student1 student1 = {};
assertEquality("chiranS", student1.name);
assertEquality(20, student1.age);
Foo5 foo5 = {};
assertEquality("chiranS", foo5.name);
assertEquality(25, foo5.age);
Foo6 foo6 = {name: "sachintha", age: 28};
assertEquality("sachintha", foo6.name);
assertEquality(28, foo6.age);
Info info = {};
assertEquality("John", info.name);
assertEquality(30, info.age);
Location location = {city: "Colombo", country: "Sri Lanka"};
assertEquality("Colombo", location.city);
assertEquality("Sri Lanka", location.country);
assertEquality("abc", location.street);
assertEquality(123, location.zipCode);
Info1 info1 = {};
assertEquality("James", info1.name);
assertEquality(30, info1.age);
}

const ASSERTION_ERROR_REASON = "AssertionError";

function assertEquality(any|error expected, any|error actual) {
Expand Down

0 comments on commit 7782a44

Please sign in to comment.