Skip to content

Commit

Permalink
Merge pull request #28443 from ProvableHQ/feat/analyzers
Browse files Browse the repository at this point in the history
[Fix] Introduce `StaticAnalysis` pass and add checks on usage of async code for safety.
  • Loading branch information
d0cd authored Nov 18, 2024
2 parents 166a3c3 + ec67f3d commit bcdb1f5
Show file tree
Hide file tree
Showing 44 changed files with 1,154 additions and 200 deletions.
22 changes: 22 additions & 0 deletions compiler/ast/src/passes/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@

use crate::*;

// TODO: The Visitor and Reconstructor patterns need a redesign so that the default implementation can easily be invoked though its implemented in an overriding trait.
// Here is a pattern that seems to work
// trait ProgramVisitor {
// // The trait method that can be overridden
// fn visit_program_scope(&mut self);
//
// // Private helper function containing the default implementation
// fn default_visit_program_scope(&mut self) {
// println!("Do default stuff");
// }
// }
//
// struct YourStruct;
//
// impl ProgramVisitor for YourStruct {
// fn visit_program_scope(&mut self) {
// println!("Do custom stuff.");
// // Call the default implementation
// self.default_visit_program_scope();
// }
// }

/// A Visitor trait for expressions in the AST.
pub trait ExpressionVisitor<'a> {
type AdditionalInput: Default;
Expand Down
21 changes: 15 additions & 6 deletions compiler/compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,24 @@ impl<'a, N: Network> Compiler<'a, N> {

/// Runs the type checker pass.
pub fn type_checker_pass(&'a self, symbol_table: SymbolTable) -> Result<(SymbolTable, StructGraph, CallGraph)> {
let (symbol_table, struct_graph, call_graph) = TypeChecker::<N>::do_pass((
let (symbol_table, struct_graph, call_graph) =
TypeChecker::<N>::do_pass((&self.ast, self.handler, symbol_table, &self.type_table))?;
if self.compiler_options.output.type_checked_symbol_table {
self.write_symbol_table_to_json("type_checked_symbol_table.json", &symbol_table)?;
}
Ok((symbol_table, struct_graph, call_graph))
}

/// Runs the static analysis pass.
pub fn static_analysis_pass(&mut self, symbol_table: &SymbolTable) -> Result<()> {
StaticAnalyzer::<N>::do_pass((
&self.ast,
self.handler,
symbol_table,
&self.type_table,
self.compiler_options.build.conditional_block_max_depth,
self.compiler_options.build.disable_conditional_branch_type_checking,
))?;
if self.compiler_options.output.type_checked_symbol_table {
self.write_symbol_table_to_json("type_checked_symbol_table.json", &symbol_table)?;
}
Ok((symbol_table, struct_graph, call_graph))
))
}

/// Runs the loop unrolling pass.
Expand Down Expand Up @@ -285,8 +291,11 @@ impl<'a, N: Network> Compiler<'a, N> {
/// Runs the compiler stages.
pub fn compiler_stages(&mut self) -> Result<(SymbolTable, StructGraph, CallGraph)> {
let st = self.symbol_table_pass()?;

let (st, struct_graph, call_graph) = self.type_checker_pass(st)?;

self.static_analysis_pass(&st)?;

// TODO: Make this pass optional.
let st = self.loop_unrolling_pass(st)?;

Expand Down
2 changes: 2 additions & 0 deletions compiler/compiler/tests/integration/utilities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ pub fn compile_and_process<'a>(parsed: &'a mut Compiler<'a, CurrentNetwork>) ->

let (st, struct_graph, call_graph) = parsed.type_checker_pass(st)?;

parsed.static_analysis_pass(&st)?;

CheckUniqueNodeIds::new().visit_program(&parsed.ast.ast);

let st = parsed.loop_unrolling_pass(st)?;
Expand Down
13 changes: 12 additions & 1 deletion compiler/passes/src/common/tree_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,20 @@ impl<N: Node> TreeNode<N> {
}

/// Removes an element from the current node.
pub fn remove_element(&mut self, element: &N) {
/// If the element does not exist, increment an internal counter which later used to generate an error that the user attempted to await a future twice.
/// Returns `true` if the element was removed but not the first one in the node.
pub fn remove_element(&mut self, element: &N) -> bool {
// Check if the element is the first one in the node.
let is_not_first = match self.elements.first() {
Some(first) => first != element,
None => false,
};
// Remove the element from the node.
if !self.elements.shift_remove(element) {
self.counter += 1;
false
} else {
is_not_first
}
}
}
3 changes: 3 additions & 0 deletions compiler/passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#![forbid(unsafe_code)]
#![doc = include_str!("../README.md")]

pub mod static_analysis;
pub use static_analysis::*;

pub mod code_generation;
pub use code_generation::*;

Expand Down
1 change: 1 addition & 0 deletions compiler/passes/src/loop_unrolling/unroller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl<'a> Unroller<'a> {
.swap(previous_constant_propagation_table.borrow().lookup_scope_by_index(index).unwrap());
self.constant_propagation_table.borrow_mut().parent =
Some(Box::new(previous_constant_propagation_table.into_inner()));

core::mem::replace(&mut self.scope_index, 0)
}

Expand Down
67 changes: 67 additions & 0 deletions compiler/passes/src/static_analysis/analyze_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::StaticAnalyzer;

use leo_ast::*;

use snarkvm::console::network::Network;

impl<'a, N: Network> ExpressionVisitor<'a> for StaticAnalyzer<'a, N> {
type AdditionalInput = ();
type Output = ();

fn visit_access(&mut self, input: &'a AccessExpression, _: &Self::AdditionalInput) -> Self::Output {
if let AccessExpression::AssociatedFunction(access) = input {
// Get the core function.
let core_function = match CoreFunction::from_symbols(access.variant.name, access.name.name) {
Some(core_function) => core_function,
None => unreachable!("Typechecking guarantees that this function exists."),
};

// Check that the future was awaited correctly.
if core_function == CoreFunction::FutureAwait {
self.assert_future_await(&access.arguments.first(), input.span());
}
}
}

fn visit_call(&mut self, input: &'a CallExpression, _: &Self::AdditionalInput) -> Self::Output {
match &*input.function {
// Note that the parser guarantees that `input.function` is always an identifier.
Expression::Identifier(ident) => {
// If the function call is an external async transition, then for all async calls that follow a non-async call,
// we must check that the async call is not an async function that takes a future as an argument.
if self.non_async_external_call_seen
&& self.variant == Some(Variant::AsyncTransition)
&& input.program.is_some()
{
// Note that this unwrap is safe since we check that `input.program` is `Some` above.
self.assert_simple_async_transition_call(input.program.unwrap(), ident.name, input.span());
}
// Otherwise look up the function and check if it is a non-async call.
if let Some(function_symbol) =
self.symbol_table.lookup_fn_symbol(Location::new(input.program, ident.name))
{
if function_symbol.variant == Variant::Transition {
self.non_async_external_call_seen = true;
}
}
}
_ => unreachable!("Parsing guarantees that a function name is always an identifier."),
}
}
}
111 changes: 111 additions & 0 deletions compiler/passes/src/static_analysis/analyze_program.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::StaticAnalyzer;

use leo_ast::{Type, *};
use leo_errors::{StaticAnalyzerError, StaticAnalyzerWarning};

use snarkvm::console::network::Network;

impl<'a, N: Network> ProgramVisitor<'a> for StaticAnalyzer<'a, N> {
fn visit_program_scope(&mut self, input: &'a ProgramScope) {
// Set the current program name.
self.current_program = Some(input.program_id.name.name);
// Do the default implementation for visiting the program scope.
input.structs.iter().for_each(|(_, c)| (self.visit_struct(c)));
input.mappings.iter().for_each(|(_, c)| (self.visit_mapping(c)));
input.functions.iter().for_each(|(_, c)| (self.visit_function(c)));
input.consts.iter().for_each(|(_, c)| (self.visit_const(c)));
}

fn visit_function(&mut self, function: &'a Function) {
// Set the function name and variant.
self.variant = Some(function.variant);

// Set `non_async_external_call_seen` to false.
self.non_async_external_call_seen = false;

// If the function is an async function, initialize the await checker.
if self.variant == Some(Variant::AsyncFunction) {
// Initialize the list of input futures. Each one must be awaited before the end of the function.
self.await_checker.set_futures(
function
.input
.iter()
.filter_map(|input| {
if let Type::Future(_) = input.type_.clone() { Some(input.identifier.name) } else { None }
})
.collect(),
);
}

self.visit_block(&function.block);

// Check that all futures were awaited exactly once.
if self.variant == Some(Variant::AsyncFunction) {
// Throw error if not all futures awaits even appear once.
if !self.await_checker.static_to_await.is_empty() {
self.emit_err(StaticAnalyzerError::future_awaits_missing(
self.await_checker
.static_to_await
.clone()
.iter()
.map(|f| f.to_string())
.collect::<Vec<String>>()
.join(", "),
function.span(),
));
} else if self.await_checker.enabled && !self.await_checker.to_await.is_empty() {
// Tally up number of paths that are unawaited and number of paths that are awaited more than once.
let (num_paths_unawaited, num_paths_duplicate_awaited, num_perfect) =
self.await_checker.to_await.iter().fold((0, 0, 0), |(unawaited, duplicate, perfect), path| {
(
unawaited + if !path.elements.is_empty() { 1 } else { 0 },
duplicate + if path.counter > 0 { 1 } else { 0 },
perfect + if path.counter > 0 || !path.elements.is_empty() { 0 } else { 1 },
)
});

// Throw error if there does not exist a path in which all futures are awaited exactly once.
if num_perfect == 0 {
self.emit_err(StaticAnalyzerError::no_path_awaits_all_futures_exactly_once(
self.await_checker.to_await.len(),
function.span(),
));
}

// Throw warning if not all futures are awaited in some paths.
if num_paths_unawaited > 0 {
self.emit_warning(StaticAnalyzerWarning::some_paths_do_not_await_all_futures(
self.await_checker.to_await.len(),
num_paths_unawaited,
function.span(),
));
}

// Throw warning if some futures are awaited more than once in some paths.
if num_paths_duplicate_awaited > 0 {
self.emit_warning(StaticAnalyzerWarning::some_paths_contain_duplicate_future_awaits(
self.await_checker.to_await.len(),
num_paths_duplicate_awaited,
function.span(),
));
}
}
}
}
}
54 changes: 54 additions & 0 deletions compiler/passes/src/static_analysis/analyze_statement.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use super::*;
use crate::ConditionalTreeNode;

use leo_ast::*;

impl<'a, N: Network> StatementVisitor<'a> for StaticAnalyzer<'a, N> {
fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
self.visit_expression(&input.condition, &Default::default());

// Create scope for checking awaits in `then` branch of conditional.
let current_bst_nodes: Vec<ConditionalTreeNode> =
match self.await_checker.create_then_scope(self.variant == Some(Variant::AsyncFunction), input.span) {
Ok(nodes) => nodes,
Err(warn) => return self.emit_warning(warn),
};

// Visit block.
self.visit_block(&input.then);

// Exit scope for checking awaits in `then` branch of conditional.
let saved_paths =
self.await_checker.exit_then_scope(self.variant == Some(Variant::AsyncFunction), current_bst_nodes);

if let Some(otherwise) = &input.otherwise {
match &**otherwise {
Statement::Block(stmt) => {
// Visit the otherwise-block.
self.visit_block(stmt);
}
Statement::Conditional(stmt) => self.visit_conditional(stmt),
_ => unreachable!("Else-case can only be a block or conditional statement."),
}
}

// Update the set of all possible BST paths.
self.await_checker.exit_statement_scope(self.variant == Some(Variant::AsyncFunction), saved_paths);
}
}
Loading

0 comments on commit bcdb1f5

Please sign in to comment.