diff --git a/.gitignore b/.gitignore index 41058537..f7e88f5c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ target .cargo +*.bak .vscode/* diff --git a/silverscript-lang/src/compiler.rs b/silverscript-lang/src/compiler.rs index bae71f78..74d1c3ee 100644 --- a/silverscript-lang/src/compiler.rs +++ b/silverscript-lang/src/compiler.rs @@ -24,6 +24,7 @@ use debug_value_types::infer_debug_expr_value_type; /// Prefix used for synthetic argument bindings during inline function expansion. pub const SYNTHETIC_ARG_PREFIX: &str = "__arg"; const COVENANT_POLICY_PREFIX: &str = "__covenant_policy"; +const HIDDEN_STACK_BINDING_PREFIX: &str = "__stack_shadow"; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub struct CovenantDeclCallOptions { @@ -833,6 +834,18 @@ fn compile_contract_impl<'i>( } let lowered_contract = lower_covenant_declarations(contract, &constants)?; + let mut lowered_contract = lowered_contract; + let mut next_if_id = 0usize; + for function in &mut lowered_contract.functions { + let mut scope = HashMap::new(); + for field in &lowered_contract.fields { + scope.insert(field.name.clone(), field.type_ref.clone()); + } + for param in &function.params { + scope.insert(param.name.clone(), param.type_ref.clone()); + } + function.body = normalize_function_body_if_reassignments(&function.body, &mut scope, &mut next_if_id); + } let structs = build_struct_registry(&lowered_contract)?; validate_struct_graph(&structs)?; validate_contract_struct_usage(&lowered_contract, &structs)?; @@ -1461,12 +1474,16 @@ fn store_struct_binding<'i>( Ok(()) } +fn struct_assignment_temp_name(name: &str, path: &[String]) -> String { + format!("__struct_assign_{}__{}", name, path.join("__")) +} + #[allow(clippy::too_many_arguments)] fn push_struct_leaf_stack_bindings<'i>( name: &str, type_ref: &TypeRef, - env: &HashMap>, - assigned_names: &HashSet, + env: &mut HashMap>, + _assigned_names: &HashSet, identifier_uses: &HashMap, types: &HashMap, stack_bindings: &mut HashMap, @@ -1476,7 +1493,7 @@ fn push_struct_leaf_stack_bindings<'i>( script_size: Option, contract_constants: &HashMap>, ) -> Result, CompilerError> { - if assigned_names.contains(name) || identifier_uses.get(name).copied().unwrap_or(0) < 2 { + if identifier_uses.get(name).copied().unwrap_or(0) < 2 { return Ok(Vec::new()); } @@ -1510,6 +1527,7 @@ fn push_struct_leaf_stack_bindings<'i>( contract_constants, )?; push_stack_binding(stack_bindings, &leaf_name); + env.remove(&leaf_name); added.push(leaf_name); } @@ -1626,6 +1644,15 @@ fn collect_identifier_uses<'i>(statements: &[Statement<'i>]) -> HashMap(stmt: &Statement<'i>, remaining_uses: &mut HashMap) { + let mut stmt_uses = HashMap::new(); + collect_statement_identifier_uses(stmt, &mut stmt_uses); + for (name, count) in stmt_uses { + let entry = remaining_uses.entry(name).or_insert(0); + *entry = entry.saturating_sub(count); + } +} + fn bump_identifier_use(uses: &mut HashMap, name: &str) { *uses.entry(name.to_string()).or_insert(0) += 1; } @@ -1755,6 +1782,438 @@ fn collect_assigned_names_into<'i>(statements: &[Statement<'i>], assigned: &mut } } +fn collect_if_reassigned_outer_names<'i>( + statements: &[Statement<'i>], + outer_scope: &HashSet, + ordered: &mut Vec, + seen: &mut HashSet, +) { + for stmt in statements { + match stmt { + Statement::Assign { name, .. } if outer_scope.contains(name) && seen.insert(name.clone()) => { + ordered.push(name.clone()); + } + Statement::If { then_branch, else_branch, .. } => { + collect_if_reassigned_outer_names(then_branch, outer_scope, ordered, seen); + if let Some(else_branch) = else_branch { + collect_if_reassigned_outer_names(else_branch, outer_scope, ordered, seen); + } + } + Statement::For { body, .. } => collect_if_reassigned_outer_names(body, outer_scope, ordered, seen), + _ => {} + } + } +} + +fn collect_branch_assigned_names<'i>(statements: &[Statement<'i>], assigned: &mut HashSet) { + for stmt in statements { + match stmt { + Statement::Assign { name, .. } => { + assigned.insert(name.clone()); + } + Statement::If { then_branch, else_branch, .. } => { + collect_branch_assigned_names(then_branch, assigned); + if let Some(else_branch) = else_branch { + collect_branch_assigned_names(else_branch, assigned); + } + } + Statement::For { body, .. } => collect_branch_assigned_names(body, assigned), + _ => {} + } + } +} + +fn apply_identifier_aliases<'i>(expr: &Expr<'i>, aliases: &HashMap) -> Expr<'i> { + let mut rewritten = expr.clone(); + for (from, to) in aliases { + rewritten = replace_identifier(&rewritten, from, &Expr::identifier(to.clone())); + } + rewritten +} + +fn apply_statement_aliases<'i>(stmt: &Statement<'i>, aliases: &HashMap) -> Statement<'i> { + match stmt { + Statement::VariableDefinition { type_ref, modifiers, name, expr, span, type_span, modifier_spans, name_span } => { + Statement::VariableDefinition { + type_ref: type_ref.clone(), + modifiers: modifiers.clone(), + name: name.clone(), + expr: expr.as_ref().map(|expr| apply_identifier_aliases(expr, aliases)), + span: *span, + type_span: *type_span, + modifier_spans: modifier_spans.clone(), + name_span: *name_span, + } + } + Statement::TupleAssignment { + left_type_ref, + left_name, + right_type_ref, + right_name, + expr, + span, + left_type_span, + left_name_span, + right_type_span, + right_name_span, + } => Statement::TupleAssignment { + left_type_ref: left_type_ref.clone(), + left_name: left_name.clone(), + right_type_ref: right_type_ref.clone(), + right_name: right_name.clone(), + expr: apply_identifier_aliases(expr, aliases), + span: *span, + left_type_span: *left_type_span, + left_name_span: *left_name_span, + right_type_span: *right_type_span, + right_name_span: *right_name_span, + }, + Statement::ArrayPush { name, expr, span, name_span } => Statement::ArrayPush { + name: aliases.get(name).cloned().unwrap_or_else(|| name.clone()), + expr: apply_identifier_aliases(expr, aliases), + span: *span, + name_span: *name_span, + }, + Statement::FunctionCall { name, args, span, name_span } => Statement::FunctionCall { + name: name.clone(), + args: args.iter().map(|arg| apply_identifier_aliases(arg, aliases)).collect(), + span: *span, + name_span: *name_span, + }, + Statement::FunctionCallAssign { bindings, name, args, span, name_span } => Statement::FunctionCallAssign { + bindings: bindings.clone(), + name: name.clone(), + args: args.iter().map(|arg| apply_identifier_aliases(arg, aliases)).collect(), + span: *span, + name_span: *name_span, + }, + Statement::StateFunctionCallAssign { bindings, name, args, span, name_span } => Statement::StateFunctionCallAssign { + bindings: bindings.clone(), + name: name.clone(), + args: args.iter().map(|arg| apply_identifier_aliases(arg, aliases)).collect(), + span: *span, + name_span: *name_span, + }, + Statement::StructDestructure { bindings, expr, span } => { + Statement::StructDestructure { bindings: bindings.clone(), expr: apply_identifier_aliases(expr, aliases), span: *span } + } + Statement::Assign { name, expr, span, name_span } => Statement::Assign { + name: aliases.get(name).cloned().unwrap_or_else(|| name.clone()), + expr: apply_identifier_aliases(expr, aliases), + span: *span, + name_span: *name_span, + }, + Statement::TimeOp { tx_var, expr, message, span, tx_var_span, message_span } => Statement::TimeOp { + tx_var: *tx_var, + expr: apply_identifier_aliases(expr, aliases), + message: message.clone(), + span: *span, + tx_var_span: *tx_var_span, + message_span: *message_span, + }, + Statement::Require { expr, message, span, message_span } => Statement::Require { + expr: apply_identifier_aliases(expr, aliases), + message: message.clone(), + span: *span, + message_span: *message_span, + }, + Statement::If { condition, then_branch, else_branch, span, then_span, else_span } => Statement::If { + condition: apply_identifier_aliases(condition, aliases), + then_branch: then_branch.iter().map(|stmt| apply_statement_aliases(stmt, aliases)).collect(), + else_branch: else_branch.as_ref().map(|branch| branch.iter().map(|stmt| apply_statement_aliases(stmt, aliases)).collect()), + span: *span, + then_span: *then_span, + else_span: *else_span, + }, + Statement::For { ident, start, end, max_iterations, body, span, ident_span, body_span } => Statement::For { + ident: ident.clone(), + start: apply_identifier_aliases(start, aliases), + end: apply_identifier_aliases(end, aliases), + max_iterations: apply_identifier_aliases(max_iterations, aliases), + body: body.iter().map(|stmt| apply_statement_aliases(stmt, aliases)).collect(), + span: *span, + ident_span: *ident_span, + body_span: *body_span, + }, + Statement::Return { exprs, span } => { + Statement::Return { exprs: exprs.iter().map(|expr| apply_identifier_aliases(expr, aliases)).collect(), span: *span } + } + Statement::Console { args, span } => Statement::Console { + args: args + .iter() + .map(|arg| match arg { + crate::ast::ConsoleArg::Identifier(name, arg_span) => { + crate::ast::ConsoleArg::Identifier(aliases.get(name).cloned().unwrap_or_else(|| name.clone()), *arg_span) + } + crate::ast::ConsoleArg::Literal(expr) => crate::ast::ConsoleArg::Literal(apply_identifier_aliases(expr, aliases)), + }) + .collect(), + span: *span, + }, + } +} + +fn synthetic_if_binding_name(name: &str, next_if_id: &mut usize) -> String { + let generated = format!("__if_{name}_{}", *next_if_id); + *next_if_id += 1; + generated +} + +fn original_name_from_synthetic_if_binding(name: &str) -> Option { + name.strip_prefix("__if_").and_then(|rest| rest.rsplit_once('_')).map(|(original, _)| original.to_string()) +} + +fn normalize_function_body_if_reassignments<'i>( + statements: &[Statement<'i>], + scope: &mut HashMap, + next_if_id: &mut usize, +) -> Vec> { + let mut normalized = Vec::with_capacity(statements.len()); + for stmt in statements { + let rewritten = normalize_statement_if_reassignments(stmt, scope, next_if_id); + update_scope_with_statement(&rewritten, scope); + normalized.push(rewritten); + } + normalized +} + +fn normalize_statement_if_reassignments<'i>( + stmt: &Statement<'i>, + scope: &HashMap, + next_if_id: &mut usize, +) -> Statement<'i> { + match stmt { + Statement::If { condition, then_branch, else_branch, span, then_span, else_span } => normalize_if_statement_reassignments( + condition, + then_branch, + else_branch.as_deref(), + *span, + *then_span, + *else_span, + scope, + next_if_id, + ), + Statement::For { ident, start, end, max_iterations, body, span, ident_span, body_span } => { + let mut body_scope = scope.clone(); + body_scope.insert(ident.clone(), parse_type_ref("int").expect("int type should parse")); + Statement::For { + ident: ident.clone(), + start: start.clone(), + end: end.clone(), + max_iterations: max_iterations.clone(), + body: normalize_function_body_if_reassignments(body, &mut body_scope, next_if_id), + span: *span, + ident_span: *ident_span, + body_span: *body_span, + } + } + _ => stmt.clone(), + } +} + +fn normalize_if_statement_reassignments<'i>( + condition: &Expr<'i>, + then_branch: &[Statement<'i>], + else_branch: Option<&[Statement<'i>]>, + span: span::Span<'i>, + then_span: span::Span<'i>, + else_span: Option>, + scope: &HashMap, + next_if_id: &mut usize, +) -> Statement<'i> { + let mut then_scope = scope.clone(); + let mut normalized_then = normalize_function_body_if_reassignments(then_branch, &mut then_scope, next_if_id); + let mut else_scope = scope.clone(); + let mut normalized_else = normalize_function_body_if_reassignments(else_branch.unwrap_or(&[]), &mut else_scope, next_if_id); + + let outer_scope = scope.keys().cloned().collect::>(); + let mut our_vars_reassigned = Vec::new(); + let mut seen = HashSet::new(); + collect_if_reassigned_outer_names(&normalized_then, &outer_scope, &mut our_vars_reassigned, &mut seen); + collect_if_reassigned_outer_names(&normalized_else, &outer_scope, &mut our_vars_reassigned, &mut seen); + + if our_vars_reassigned.is_empty() { + return Statement::If { + condition: condition.clone(), + then_branch: normalized_then, + else_branch: else_branch.map(|_| normalized_else), + span, + then_span, + else_span, + }; + } + + let mut then_assigned = HashSet::new(); + collect_branch_assigned_names(&normalized_then, &mut then_assigned); + let mut else_assigned = HashSet::new(); + collect_branch_assigned_names(&normalized_else, &mut else_assigned); + + for name in &our_vars_reassigned { + if !then_assigned.contains(name) { + normalized_then.push(Statement::Assign { + name: name.clone(), + expr: Expr::identifier(name.clone()), + span, + name_span: span, + }); + } + if !else_assigned.contains(name) { + normalized_else.push(Statement::Assign { + name: name.clone(), + expr: Expr::identifier(name.clone()), + span, + name_span: span, + }); + } + } + + let targeted = our_vars_reassigned.iter().cloned().collect::>(); + let synthetic_names = + our_vars_reassigned.iter().map(|name| (name.clone(), synthetic_if_binding_name(name, next_if_id))).collect::>(); + + let mut rewrite_then_scope = scope.clone(); + let rewritten_then = rewrite_if_branch_reassignments( + &normalized_then, + &mut rewrite_then_scope, + &our_vars_reassigned, + &targeted, + &synthetic_names, + scope, + next_if_id, + ); + + let mut rewrite_else_scope = scope.clone(); + let rewritten_else = rewrite_if_branch_reassignments( + &normalized_else, + &mut rewrite_else_scope, + &our_vars_reassigned, + &targeted, + &synthetic_names, + scope, + next_if_id, + ); + + Statement::If { + condition: condition.clone(), + then_branch: rewritten_then, + else_branch: Some(rewritten_else), + span, + then_span, + else_span: Some(else_span.unwrap_or(span)), + } +} + +fn rewrite_if_branch_reassignments<'i>( + statements: &[Statement<'i>], + scope: &mut HashMap, + ordered_targets: &[String], + targeted: &HashSet, + synthetic_names: &HashMap, + outer_scope: &HashMap, + next_if_id: &mut usize, +) -> Vec> { + let mut rewritten = Vec::new(); + let mut aliases = HashMap::new(); + + for stmt in statements { + let aliased_stmt = apply_statement_aliases(stmt, &aliases); + match aliased_stmt { + Statement::Assign { name, expr, span, name_span } if targeted.contains(&name) => { + let synthetic_name = synthetic_names.get(&name).cloned().expect("synthetic binding for reassigned if variable"); + let type_ref = outer_scope.get(&name).cloned().expect("outer if variable type"); + if aliases.contains_key(&name) { + rewritten.push(Statement::Assign { name: synthetic_name, expr, span, name_span }); + } else { + aliases.insert(name.clone(), synthetic_name.clone()); + scope.insert(synthetic_name.clone(), type_ref.clone()); + rewritten.push(Statement::VariableDefinition { + type_ref, + modifiers: Vec::new(), + name: synthetic_name, + expr: Some(expr), + span, + type_span: span, + modifier_spans: Vec::new(), + name_span, + }); + } + } + Statement::If { condition, then_branch, else_branch, span, then_span, else_span } => { + rewritten.push(normalize_if_statement_reassignments( + &condition, + &then_branch, + else_branch.as_deref(), + span, + then_span, + else_span, + scope, + next_if_id, + )); + } + other => { + update_scope_with_statement(&other, scope); + rewritten.push(other); + } + } + } + + for original in ordered_targets { + let synthetic = synthetic_names.get(original).cloned().expect("synthetic binding"); + if !aliases.contains_key(original) { + let type_ref = outer_scope.get(original).cloned().expect("outer if variable type"); + rewritten.push(Statement::VariableDefinition { + type_ref, + modifiers: Vec::new(), + name: synthetic.clone(), + expr: Some(Expr::identifier(original.clone())), + span: span::Span::default(), + type_span: span::Span::default(), + modifier_spans: Vec::new(), + name_span: span::Span::default(), + }); + } + rewritten.push(Statement::Assign { + name: original.clone(), + expr: Expr::identifier(synthetic), + span: span::Span::default(), + name_span: span::Span::default(), + }); + } + + rewritten +} + +fn update_scope_with_statement<'i>(stmt: &Statement<'i>, scope: &mut HashMap) { + match stmt { + Statement::VariableDefinition { type_ref, name, .. } => { + scope.insert(name.clone(), type_ref.clone()); + } + Statement::TupleAssignment { left_type_ref, left_name, right_type_ref, right_name, .. } => { + scope.insert(left_name.clone(), left_type_ref.clone()); + scope.insert(right_name.clone(), right_type_ref.clone()); + } + Statement::FunctionCallAssign { bindings, .. } => { + for binding in bindings { + scope.insert(binding.name.clone(), binding.type_ref.clone()); + } + } + Statement::StateFunctionCallAssign { bindings, .. } | Statement::StructDestructure { bindings, .. } => { + for binding in bindings { + scope.insert(binding.name.clone(), binding.type_ref.clone()); + } + } + Statement::If { .. } + | Statement::Assign { .. } + | Statement::ArrayPush { .. } + | Statement::FunctionCall { .. } + | Statement::TimeOp { .. } + | Statement::Require { .. } + | Statement::For { .. } + | Statement::Return { .. } + | Statement::Console { .. } => {} + } +} + fn push_stack_binding(bindings: &mut HashMap, name: &str) { for depth in bindings.values_mut() { *depth += 1; @@ -1762,16 +2221,56 @@ fn push_stack_binding(bindings: &mut HashMap, name: &str) { bindings.insert(name.to_string(), 0); } -fn pop_stack_bindings(bindings: &mut HashMap, names: &[String]) { - if names.is_empty() { - return; +fn rebind_stack_binding(bindings: &mut HashMap, name: &str) { + if let Some(previous_depth) = bindings.get(name).copied() { + let mut shadow_index = 0usize; + loop { + let shadow_name = format!("{HIDDEN_STACK_BINDING_PREFIX}_{name}_{shadow_index}"); + if let std::collections::hash_map::Entry::Vacant(e) = bindings.entry(shadow_name) { + e.insert(previous_depth); + break; + } + shadow_index += 1; + } } + push_stack_binding(bindings, name); +} - for name in names { - bindings.remove(name); +fn drop_stack_binding(bindings: &mut HashMap, name: &str, builder: &mut ScriptBuilder) -> Result<(), CompilerError> { + let Some(removed_depth) = bindings.remove(name) else { + return Ok(()); + }; + + if removed_depth < 0 { + return Err(CompilerError::Unsupported(format!("invalid negative stack depth for binding '{name}'"))); } + + if removed_depth == 0 { + builder.add_op(OpDrop)?; + } else { + builder.add_i64(removed_depth)?; + builder.add_op(OpRoll)?; + builder.add_op(OpDrop)?; + } + for depth in bindings.values_mut() { - *depth -= names.len() as i64; + if *depth > removed_depth { + *depth -= 1; + } + } + + Ok(()) +} + +fn consume_stack_binding(bindings: &mut HashMap, name: &str) { + let Some(removed_depth) = bindings.remove(name) else { + return; + }; + + for depth in bindings.values_mut() { + if *depth > removed_depth { + *depth -= 1; + } } } @@ -2335,8 +2834,6 @@ fn compile_entrypoint_function<'i>( .enumerate() .map(|(index, name)| (name.clone(), (contract_field_count + (param_count - 1 - index)) as i64)) .collect::>(); - let initial_stack_binding_count = stack_bindings.len() + contract_field_count; - for (index, field) in contract_fields.iter().enumerate() { stack_bindings.insert(field.name.clone(), (contract_field_count - 1 - index) as i64); } @@ -2399,6 +2896,7 @@ fn compile_entrypoint_function<'i>( recorder.begin_entrypoint(&function.name, function, contract_fields); let body_len = function.body.len(); + let mut remaining_identifier_uses = identifier_uses.clone(); for (index, stmt) in function.body.iter().enumerate() { recorder.begin_statement_at(builder.script().len(), &env, &stack_bindings); if let Statement::Return { exprs, .. } = stmt { @@ -2425,6 +2923,7 @@ fn compile_entrypoint_function<'i>( &mut env, &assigned_names, &identifier_uses, + &remaining_identifier_uses, &mut types, &mut stack_bindings, &mut builder, @@ -2442,6 +2941,7 @@ fn compile_entrypoint_function<'i>( .map_err(|err| err.with_span(&stmt.span()))?; } recorder.finish_statement_at(stmt, builder.script().len(), &env, &types, &stack_bindings)?; + consume_statement_identifier_uses(stmt, &mut remaining_identifier_uses); } let flattened_returns = if has_return { @@ -2458,27 +2958,31 @@ fn compile_entrypoint_function<'i>( Vec::new() }; + let remaining_param_bindings = flattened_param_names.iter().filter(|name| stack_bindings.contains_key(*name)).count(); + let remaining_contract_field_bindings = contract_fields.iter().filter(|field| stack_bindings.contains_key(&field.name)).count(); + let remaining_local_bindings = stack_bindings.len().saturating_sub(remaining_param_bindings + remaining_contract_field_bindings); + let return_count = flattened_returns.len(); if return_count == 0 { - for _ in 0..stack_bindings.len().saturating_sub(initial_stack_binding_count) { + for _ in 0..remaining_local_bindings { builder.add_i64(return_count as i64)?; builder.add_op(OpRoll)?; builder.add_op(OpDrop)?; } - for _ in 0..param_count { + for _ in 0..remaining_param_bindings { builder.add_op(OpDrop)?; } - for _ in 0..contract_field_count { + for _ in 0..remaining_contract_field_bindings { builder.add_op(OpDrop)?; } builder.add_op(OpTrue)?; } else { let mut stack_depth = 0i64; for expr in &flattened_returns { - compile_expr( + compile_tracked_expr( expr, &env, - &stack_bindings, + &mut stack_bindings, &types, &mut builder, options, @@ -2488,17 +2992,17 @@ fn compile_entrypoint_function<'i>( constants, )?; } - for _ in 0..stack_bindings.len().saturating_sub(initial_stack_binding_count) { + for _ in 0..remaining_local_bindings { builder.add_i64(return_count as i64)?; builder.add_op(OpRoll)?; builder.add_op(OpDrop)?; } - for _ in 0..param_count { + for _ in 0..remaining_param_bindings { builder.add_i64(return_count as i64)?; builder.add_op(OpRoll)?; builder.add_op(OpDrop)?; } - for _ in 0..contract_field_count { + for _ in 0..remaining_contract_field_bindings { builder.add_i64(return_count as i64)?; builder.add_op(OpRoll)?; builder.add_op(OpDrop)?; @@ -2515,6 +3019,7 @@ fn compile_statement<'i>( env: &mut HashMap>, assigned_names: &HashSet, identifier_uses: &HashMap, + remaining_identifier_uses: &HashMap, types: &mut HashMap, stack_bindings: &mut HashMap, builder: &mut ScriptBuilder, @@ -2690,15 +3195,26 @@ fn compile_statement<'i>( } types.insert(name.clone(), effective_type_name.clone()); let existing_is_predeclared_default = is_predeclared_scalar_default(name, &effective_type_name, env); - - if !assigned_names.contains(name) - && identifier_uses.get(name).copied().unwrap_or(0) >= 2 - && (!env.contains_key(name) || existing_is_predeclared_default) - && !stack_bindings.contains_key(name) - && matches!(effective_type_name.as_str(), "int" | "bool" | "byte") + let force_stack_materialization = name.starts_with("__if_") + && matches!(&expr.kind, ExprKind::Identifier(identifier) if stack_bindings.contains_key(identifier)); + + if force_stack_materialization + || (!assigned_names.contains(name) + && identifier_uses.get(name).copied().unwrap_or(0) >= 2 + && (!env.contains_key(name) || existing_is_predeclared_default) + && !stack_bindings.contains_key(name) + && matches!(effective_type_name.as_str(), "int" | "bool" | "byte")) { + let mut forced_last_uses = HashMap::new(); + if let Some(original_name) = original_name_from_synthetic_if_binding(name) { + if stack_bindings.contains_key(&original_name) { + let mut expr_uses = HashMap::new(); + collect_expr_identifier_uses(&expr, &mut expr_uses); + forced_last_uses.insert(original_name.clone(), expr_uses.remove(&original_name).unwrap_or(1)); + } + } let mut stack_depth = 0i64; - compile_expr( + compile_tracked_expr_with_forced_last_uses( &expr, env, stack_bindings, @@ -2709,6 +3225,7 @@ fn compile_statement<'i>( &mut stack_depth, script_size, contract_constants, + forced_last_uses, )?; env.insert(name.clone(), expr); push_stack_binding(stack_bindings, name); @@ -2825,9 +3342,18 @@ fn compile_statement<'i>( Ok(Vec::new()) } Statement::Require { expr, .. } => { + let mut expr_uses = HashMap::new(); + collect_expr_identifier_uses(expr, &mut expr_uses); let expr = lower_runtime_expr(expr, types, structs)?; + let forced_last_uses = expr_uses + .into_iter() + .filter(|(name, _)| stack_bindings.contains_key(name)) + .filter(|(name, _)| matches!(types.get(name).map(String::as_str), Some("int" | "bool" | "byte"))) + .filter(|(_, count)| *count == 1) + .filter(|(name, count)| remaining_identifier_uses.get(name).copied().unwrap_or(0) == *count) + .collect::>(); let mut stack_depth = 0i64; - compile_expr( + compile_tracked_expr_with_forced_last_uses( &expr, env, stack_bindings, @@ -2838,6 +3364,7 @@ fn compile_statement<'i>( &mut stack_depth, script_size, contract_constants, + forced_last_uses, )?; builder.add_op(OpVerify)?; Ok(Vec::new()) @@ -3041,6 +3568,7 @@ fn compile_statement<'i>( env, assigned_names, identifier_uses, + remaining_identifier_uses, types, stack_bindings, builder, @@ -3141,7 +3669,7 @@ fn compile_statement<'i>( && matches!(binding_type_name.as_str(), "int" | "bool" | "byte") { let mut stack_depth = 0i64; - compile_expr( + compile_tracked_expr( &lowered, env, stack_bindings, @@ -3163,12 +3691,65 @@ fn compile_statement<'i>( } Ok(added_stack_locals) } - Statement::Assign { name, expr, .. } => { + Statement::Assign { name, expr, span, name_span } => { if let Some(type_name) = types.get(name) { let expected_type_ref = parse_type_ref(type_name)?; - if struct_name_from_type_ref(&expected_type_ref, structs).is_some() - || struct_array_name_from_type_ref(&expected_type_ref, structs).is_some() - { + if struct_name_from_type_ref(&expected_type_ref, structs).is_some() { + let lowered_values = lower_runtime_struct_expr( + expr, + &expected_type_ref, + types, + structs, + contract_fields, + contract_constants, + contract_field_prefix_len, + )?; + let leaf_bindings = flatten_type_ref_leaves(&expected_type_ref, structs)?; + let original_env = env.clone(); + let mut temps = Vec::with_capacity(leaf_bindings.len()); + + types.insert(name.clone(), type_name.clone()); + for ((path, field_type), lowered_expr) in leaf_bindings.into_iter().zip(lowered_values.into_iter()) { + let leaf_name = flattened_struct_name(name, &path); + let temp_name = struct_assignment_temp_name(name, &path); + let resolved_temp = resolve_expr_for_runtime(lowered_expr, &original_env, types, &mut HashSet::new())?; + types.insert(temp_name.clone(), type_name_from_ref(&field_type)); + env.insert(temp_name.clone(), resolved_temp); + temps.push((leaf_name, temp_name)); + } + + for (leaf_name, temp_name) in temps { + compile_statement( + &Statement::Assign { + name: leaf_name, + expr: Expr::identifier(temp_name.clone()), + span: *span, + name_span: *name_span, + }, + env, + assigned_names, + identifier_uses, + remaining_identifier_uses, + types, + stack_bindings, + builder, + options, + contract_fields, + contract_field_prefix_len, + contract_constants, + structs, + functions, + function_order, + function_index, + script_size, + recorder, + )?; + env.remove(&temp_name); + types.remove(&temp_name); + } + return Ok(Vec::new()); + } + if struct_array_name_from_type_ref(&expected_type_ref, structs).is_some() { return store_struct_binding( name, &expected_type_ref, @@ -3213,17 +3794,78 @@ fn compile_statement<'i>( .unwrap_or_default() ))); } + let should_rebind_stack = !name.starts_with("__if_") + && (stack_bindings.contains_key(name) + || matches!(&lowered_expr.kind, ExprKind::Identifier(identifier) if identifier.starts_with("__if_"))); + if should_rebind_stack { + let mut forced_last_uses = HashMap::new(); + if stack_bindings.contains_key(name) { + let mut expr_uses = HashMap::new(); + collect_expr_identifier_uses(&lowered_expr, &mut expr_uses); + forced_last_uses.insert(name.clone(), expr_uses.remove(name).unwrap_or(1)); + } + let mut stack_depth = 0i64; + compile_tracked_expr_with_forced_last_uses( + &lowered_expr, + env, + stack_bindings, + types, + builder, + options, + &mut HashSet::new(), + &mut stack_depth, + script_size, + contract_constants, + forced_last_uses, + )?; + rebind_stack_binding(stack_bindings, name); + let resolved = resolve_expr_for_runtime(lowered_expr, env, types, &mut HashSet::new())?; + env.insert(name.clone(), resolved); + } else { + let updated = if let Some(previous) = env.get(name) { + replace_identifier(&lowered_expr, name, previous) + } else { + lowered_expr + }; + let resolved = resolve_expr_for_runtime(updated, env, types, &mut HashSet::new())?; + env.insert(name.clone(), resolved); + } + return Ok(Vec::new()); + } + let lowered_expr = lower_runtime_expr(expr, types, structs)?; + let should_rebind_stack = !name.starts_with("__if_") + && (stack_bindings.contains_key(name) + || matches!(&lowered_expr.kind, ExprKind::Identifier(identifier) if identifier.starts_with("__if_"))); + if should_rebind_stack { + let mut forced_last_uses = HashMap::new(); + if stack_bindings.contains_key(name) { + let mut expr_uses = HashMap::new(); + collect_expr_identifier_uses(&lowered_expr, &mut expr_uses); + forced_last_uses.insert(name.clone(), expr_uses.remove(name).unwrap_or(1)); + } + let mut stack_depth = 0i64; + compile_tracked_expr_with_forced_last_uses( + &lowered_expr, + env, + stack_bindings, + types, + builder, + options, + &mut HashSet::new(), + &mut stack_depth, + script_size, + contract_constants, + forced_last_uses, + )?; + rebind_stack_binding(stack_bindings, name); + let resolved = resolve_expr_for_runtime(lowered_expr, env, types, &mut HashSet::new())?; + env.insert(name.clone(), resolved); + } else { let updated = if let Some(previous) = env.get(name) { replace_identifier(&lowered_expr, name, previous) } else { lowered_expr }; let resolved = resolve_expr_for_runtime(updated, env, types, &mut HashSet::new())?; env.insert(name.clone(), resolved); - return Ok(Vec::new()); } - let lowered_expr = lower_runtime_expr(expr, types, structs)?; - let updated = - if let Some(previous) = env.get(name) { replace_identifier(&lowered_expr, name, previous) } else { lowered_expr }; - let resolved = resolve_expr_for_runtime(updated, env, types, &mut HashSet::new())?; - env.insert(name.clone(), resolved); Ok(Vec::new()) } Statement::Console { .. } => Ok(Vec::new()), @@ -3769,6 +4411,7 @@ fn compile_inline_call<'i>( let mut returns: Vec> = Vec::new(); let initial_stack_binding_count = bindings.stack_bindings.len(); + let mut remaining_identifier_uses = identifier_uses.clone(); for param in &function.params { let param_type_name = type_name_from_ref(¶m.type_ref); if !matches!(param_type_name.as_str(), "int" | "bool" | "byte") @@ -3827,6 +4470,7 @@ fn compile_inline_call<'i>( &mut bindings.env, &assigned_names, &identifier_uses, + &remaining_identifier_uses, &mut bindings.types, &mut bindings.stack_bindings, builder, @@ -3844,6 +4488,7 @@ fn compile_inline_call<'i>( .map_err(|err| err.with_span(&stmt.span()))?; } recorder.finish_statement_at(stmt, builder.script().len(), &bindings.env, &bindings.types, &bindings.stack_bindings)?; + consume_statement_identifier_uses(stmt, &mut remaining_identifier_uses); } for _ in 0..bindings.stack_bindings.len().saturating_sub(initial_stack_binding_count) { @@ -3878,6 +4523,10 @@ fn compile_if_statement<'i>( recorder: &mut DebugRecorder<'i>, ) -> Result<(), CompilerError> { let condition = lower_runtime_expr(condition, types, structs)?; + let original_env = env.clone(); + let original_stack_bindings = stack_bindings.clone(); + let visible_names = original_env.keys().chain(original_stack_bindings.keys()).cloned().collect::>(); + let mut stack_depth = 0i64; compile_expr( &condition, @@ -3893,9 +4542,9 @@ fn compile_if_statement<'i>( )?; builder.add_op(OpIf)?; - let original_env = env.clone(); let mut then_env = original_env.clone(); let mut then_types = types.clone(); + let mut then_stack_bindings = original_stack_bindings.clone(); predeclare_if_branch_locals(then_branch, &mut then_env, &mut then_types, structs)?; compile_block( then_branch, @@ -3903,7 +4552,7 @@ fn compile_if_statement<'i>( assigned_names, identifier_uses, &mut then_types, - stack_bindings, + &mut then_stack_bindings, builder, options, contract_fields, @@ -3918,10 +4567,12 @@ fn compile_if_statement<'i>( recorder, )?; + builder.add_op(OpElse)?; + let mut else_env = original_env.clone(); + let mut else_types = types.clone(); + let mut else_stack_bindings = original_stack_bindings.clone(); if let Some(else_branch) = else_branch { - builder.add_op(OpElse)?; - let mut else_types = types.clone(); predeclare_if_branch_locals(else_branch, &mut else_env, &mut else_types, structs)?; compile_block( else_branch, @@ -3929,7 +4580,7 @@ fn compile_if_statement<'i>( assigned_names, identifier_uses, &mut else_types, - stack_bindings, + &mut else_stack_bindings, builder, options, contract_fields, @@ -3947,11 +4598,70 @@ fn compile_if_statement<'i>( builder.add_op(OpEndIf)?; + let branch_only_binding_count = count_branch_only_stack_bindings_after_if( + &visible_names, + &original_stack_bindings, + &then_stack_bindings, + &else_stack_bindings, + ); + for _ in 0..branch_only_binding_count { + builder.add_op(OpDrop)?; + } + let resolved_condition = resolve_expr_for_runtime(condition, &original_env, types, &mut HashSet::new())?; merge_env_after_if(env, &original_env, &then_env, &else_env, &resolved_condition); + merge_types_after_if(types, &then_types, &else_types); + *stack_bindings = merge_stack_bindings_after_if( + &visible_names, + &original_stack_bindings, + &then_stack_bindings, + &else_stack_bindings, + branch_only_binding_count as i64, + ); Ok(()) } +fn count_branch_only_stack_bindings_after_if( + visible_names: &HashSet, + original_stack_bindings: &HashMap, + then_stack_bindings: &HashMap, + else_stack_bindings: &HashMap, +) -> usize { + let names = then_stack_bindings.keys().chain(else_stack_bindings.keys()).cloned().collect::>(); + names + .into_iter() + .filter(|name| !original_stack_bindings.contains_key(name)) + .filter(|name| !visible_names.contains(name)) + .filter(|name| matches!((then_stack_bindings.get(name), else_stack_bindings.get(name)), (Some(then_depth), Some(else_depth)) if then_depth == else_depth)) + .count() +} + +fn merge_stack_bindings_after_if( + visible_names: &HashSet, + original_stack_bindings: &HashMap, + then_stack_bindings: &HashMap, + else_stack_bindings: &HashMap, + dropped_binding_count: i64, +) -> HashMap { + let mut merged = HashMap::new(); + for name in visible_names { + match (then_stack_bindings.get(name), else_stack_bindings.get(name)) { + (Some(then_depth), Some(else_depth)) if then_depth == else_depth => { + merged.insert(name.clone(), *then_depth - dropped_binding_count); + } + (Some(then_depth), None) if original_stack_bindings.get(name) == Some(then_depth) => { + merged.insert(name.clone(), *then_depth); + } + (None, Some(else_depth)) if original_stack_bindings.get(name) == Some(else_depth) => { + merged.insert(name.clone(), *else_depth); + } + (None, None) => {} + _ => {} + } + } + merged +} + fn merge_env_after_if<'i>( env: &mut HashMap>, original_env: &HashMap>, @@ -3981,6 +4691,22 @@ fn merge_env_after_if<'i>( } } +fn merge_types_after_if( + types: &mut HashMap, + then_types: &HashMap, + else_types: &HashMap, +) { + let names = then_types.keys().chain(else_types.keys()).cloned().collect::>(); + for name in names { + match (then_types.get(&name), else_types.get(&name)) { + (Some(then_type), Some(else_type)) if then_type == else_type => { + types.insert(name, then_type.clone()); + } + _ => {} + } + } +} + fn default_scalar_expr(type_name: &str) -> Option> { match type_name { "int" => Some(Expr::int(0)), @@ -4103,6 +4829,8 @@ fn compile_block<'i>( recorder: &mut DebugRecorder<'i>, ) -> Result<(), CompilerError> { let mut added_stack_locals = Vec::new(); + let original_stack_binding_names = + if scoped_stack_locals { Some(stack_bindings.keys().cloned().collect::>()) } else { None }; for stmt in statements { recorder.begin_statement_at(builder.script().len(), env, stack_bindings); added_stack_locals.extend( @@ -4111,6 +4839,7 @@ fn compile_block<'i>( env, assigned_names, identifier_uses, + identifier_uses, types, stack_bindings, builder, @@ -4131,13 +4860,26 @@ fn compile_block<'i>( } if scoped_stack_locals && !added_stack_locals.is_empty() { - for _ in 0..added_stack_locals.len() { - builder.add_op(OpDrop)?; + for name in &added_stack_locals { + drop_stack_binding(stack_bindings, name, builder)?; } for name in &added_stack_locals { - types.remove(name); + if !name.starts_with("__if_") { + types.remove(name); + } + } + } + + if let Some(original_stack_binding_names) = original_stack_binding_names { + let added_hidden_stack_bindings = stack_bindings + .keys() + .filter(|name| name.starts_with(HIDDEN_STACK_BINDING_PREFIX)) + .filter(|name| !original_stack_binding_names.contains(*name)) + .cloned() + .collect::>(); + for name in added_hidden_stack_bindings { + drop_stack_binding(stack_bindings, &name, builder)?; } - pop_stack_bindings(stack_bindings, &added_stack_locals); } Ok(()) @@ -4865,6 +5607,46 @@ struct CompilationScope<'a, 'i> { types: &'a HashMap, } +#[derive(Default)] +struct ExprTrackingState { + consumed_stack_bindings: Vec, + forced_last_uses: HashMap, +} + +impl ExprTrackingState { + fn current_stack_binding_index(&self, name: &str, stack_bindings: &HashMap) -> Option { + let original = *stack_bindings.get(name)?; + let removed_above = self + .consumed_stack_bindings + .iter() + .filter_map(|consumed_name| stack_bindings.get(consumed_name)) + .filter(|depth| **depth < original) + .count() as i64; + Some(original - removed_above) + } + + fn should_roll_last_stack_use(&mut self, name: &str) -> bool { + let Some(remaining) = self.forced_last_uses.get_mut(name) else { + return false; + }; + if *remaining == 0 { + return false; + } + *remaining -= 1; + if *remaining != 0 { + return false; + } + self.consumed_stack_bindings.push(name.to_string()); + true + } + + fn finish(self, stack_bindings: &mut HashMap) { + for name in self.consumed_stack_bindings { + consume_stack_binding(stack_bindings, &name); + } + } +} + fn compile_expr<'i>( expr: &Expr<'i>, env: &HashMap>, @@ -4876,6 +5658,36 @@ fn compile_expr<'i>( stack_depth: &mut i64, script_size: Option, contract_constants: &HashMap>, +) -> Result<(), CompilerError> { + let mut tracking = ExprTrackingState::default(); + compile_expr_with_tracking( + expr, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + &mut tracking, + ) +} + +#[allow(clippy::too_many_arguments)] +fn compile_expr_with_tracking<'i>( + expr: &Expr<'i>, + env: &HashMap>, + stack_bindings: &HashMap, + types: &HashMap, + builder: &mut ScriptBuilder, + options: CompileOptions, + visiting: &mut HashSet, + stack_depth: &mut i64, + script_size: Option, + contract_constants: &HashMap>, + tracking: &mut ExprTrackingState, ) -> Result<(), CompilerError> { let scope = CompilationScope { env, stack_bindings, types }; match &expr.kind { @@ -4922,10 +5734,10 @@ fn compile_expr<'i>( if !visiting.insert(name.clone()) { return Err(CompilerError::CyclicIdentifier(name.clone())); } - if let Some(index) = stack_bindings.get(name) { - builder.add_i64(*index + *stack_depth)?; + if let Some(index) = tracking.current_stack_binding_index(name, stack_bindings) { + builder.add_i64(index + *stack_depth)?; *stack_depth += 1; - builder.add_op(OpPick)?; + builder.add_op(if tracking.should_roll_last_stack_use(name) { OpRoll } else { OpPick })?; visiting.remove(name); return Ok(()); } @@ -4941,7 +5753,7 @@ fn compile_expr<'i>( } } } - compile_expr( + compile_expr_with_tracking( resolved_expr, env, stack_bindings, @@ -4952,6 +5764,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; visiting.remove(name); return Ok(()); @@ -4960,7 +5773,7 @@ fn compile_expr<'i>( Err(CompilerError::UndefinedIdentifier(name.clone())) } ExprKind::IfElse { condition, then_expr, else_expr } => { - compile_expr( + compile_expr_with_tracking( condition, env, stack_bindings, @@ -4971,11 +5784,12 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpIf)?; *stack_depth -= 1; let depth_before = *stack_depth; - compile_expr( + compile_expr_with_tracking( then_expr, env, stack_bindings, @@ -4986,10 +5800,11 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpElse)?; *stack_depth = depth_before; - compile_expr( + compile_expr_with_tracking( else_expr, env, stack_bindings, @@ -5000,14 +5815,24 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpEndIf)?; *stack_depth = depth_before + 1; Ok(()) } - ExprKind::Call { name, args, .. } => { - compile_call_expr(name.as_str(), args, &scope, builder, options, visiting, stack_depth, script_size, contract_constants) - } + ExprKind::Call { name, args, .. } => compile_call_expr( + name.as_str(), + args, + &scope, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + ), ExprKind::New { name, args, .. } => match name.as_str() { "LockingBytecodeNullData" => { if args.len() != 1 { @@ -5022,7 +5847,7 @@ fn compile_expr<'i>( if args.len() != 1 { return Err(CompilerError::Unsupported("ScriptPubKeyP2PK expects a single pubkey argument".to_string())); } - compile_expr( + compile_expr_with_tracking( &args[0], env, stack_bindings, @@ -5033,6 +5858,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_data(&[0x00, 0x00, OpData32])?; *stack_depth += 1; @@ -5049,7 +5875,7 @@ fn compile_expr<'i>( if args.len() != 1 { return Err(CompilerError::Unsupported("ScriptPubKeyP2SH expects a single bytes32 argument".to_string())); } - compile_expr( + compile_expr_with_tracking( &args[0], env, stack_bindings, @@ -5060,6 +5886,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_data(&[0x00, 0x00])?; *stack_depth += 1; @@ -5086,7 +5913,7 @@ fn compile_expr<'i>( "ScriptPubKeyP2SHFromRedeemScript expects a single redeem_script argument".to_string(), )); } - compile_expr( + compile_expr_with_tracking( &args[0], env, stack_bindings, @@ -5097,6 +5924,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpBlake2b)?; builder.add_data(&[0x00, 0x00])?; @@ -5121,7 +5949,19 @@ fn compile_expr<'i>( name => Err(CompilerError::Unsupported(format!("unknown constructor: {name}"))), }, ExprKind::Unary { op, expr } => { - compile_expr(expr, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + expr, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; match op { UnaryOp::Not => builder.add_op(OpNot)?, UnaryOp::Neg => builder.add_op(OpNegate)?, @@ -5144,6 +5984,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; compile_concat_operand( right, @@ -5156,9 +5997,10 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; } else { - compile_expr( + compile_expr_with_tracking( left, env, stack_bindings, @@ -5169,8 +6011,9 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; - compile_expr( + compile_expr_with_tracking( right, env, stack_bindings, @@ -5181,6 +6024,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; } match op { @@ -5258,6 +6102,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, ), ExprKind::UnarySuffix { source, kind, .. } => match kind { UnarySuffixKind::Length => compile_length_expr( @@ -5271,6 +6116,7 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, ), UnarySuffixKind::Reverse => Err(CompilerError::Unsupported("reverse() is not supported".to_string())), }, @@ -5295,7 +6141,7 @@ fn compile_expr<'i>( }; let element_size = fixed_type_size(&element_type) .ok_or_else(|| CompilerError::Unsupported("array element type must have known size".to_string()))?; - compile_expr( + compile_expr_with_tracking( &resolved_source, env, stack_bindings, @@ -5306,8 +6152,21 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, + )?; + compile_expr_with_tracking( + index, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, )?; - compile_expr(index, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; builder.add_i64(element_size)?; *stack_depth += 1; builder.add_op(OpMul)?; @@ -5326,7 +6185,7 @@ fn compile_expr<'i>( Ok(()) } ExprKind::Slice { source, start, end, .. } => { - compile_expr( + compile_expr_with_tracking( source, env, stack_bindings, @@ -5337,9 +6196,34 @@ fn compile_expr<'i>( stack_depth, script_size, contract_constants, + tracking, + )?; + compile_expr_with_tracking( + start, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; + compile_expr_with_tracking( + end, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, )?; - compile_expr(start, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; - compile_expr(end, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; builder.add_op(OpSubstr)?; *stack_depth -= 2; Ok(()) @@ -5385,7 +6269,19 @@ fn compile_expr<'i>( Ok(()) } ExprKind::Introspection { kind, index, .. } => { - compile_expr(index, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + index, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; match kind { IntrospectionKind::InputValue => { builder.add_op(OpTxInputAmount)?; @@ -5429,6 +6325,66 @@ fn compile_expr<'i>( } } +#[allow(clippy::too_many_arguments)] +fn compile_tracked_expr<'i>( + expr: &Expr<'i>, + env: &HashMap>, + stack_bindings: &mut HashMap, + types: &HashMap, + builder: &mut ScriptBuilder, + options: CompileOptions, + visiting: &mut HashSet, + stack_depth: &mut i64, + script_size: Option, + contract_constants: &HashMap>, +) -> Result<(), CompilerError> { + compile_tracked_expr_with_forced_last_uses( + expr, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + HashMap::new(), + ) +} + +#[allow(clippy::too_many_arguments)] +fn compile_tracked_expr_with_forced_last_uses<'i>( + expr: &Expr<'i>, + env: &HashMap>, + stack_bindings: &mut HashMap, + types: &HashMap, + builder: &mut ScriptBuilder, + options: CompileOptions, + visiting: &mut HashSet, + stack_depth: &mut i64, + script_size: Option, + contract_constants: &HashMap>, + forced_last_uses: HashMap, +) -> Result<(), CompilerError> { + let mut tracking = ExprTrackingState { consumed_stack_bindings: Vec::new(), forced_last_uses }; + let result = compile_expr_with_tracking( + expr, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + &mut tracking, + ); + tracking.finish(stack_bindings); + result +} + #[allow(clippy::too_many_arguments)] fn compile_split_part<'i>( source: &Expr<'i>, @@ -5443,11 +6399,36 @@ fn compile_split_part<'i>( stack_depth: &mut i64, script_size: Option, contract_constants: &HashMap>, + tracking: &mut ExprTrackingState, ) -> Result<(), CompilerError> { - compile_expr(source, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + source, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; match part { SplitPart::Left => { - compile_expr(index, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + index, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; builder.add_i64(0)?; *stack_depth += 1; builder.add_op(OpSwap)?; @@ -5458,7 +6439,19 @@ fn compile_split_part<'i>( SplitPart::Right => { builder.add_op(OpSize)?; *stack_depth += 1; - compile_expr(index, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + index, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; builder.add_op(OpSwap)?; builder.add_op(OpSubstr)?; *stack_depth -= 2; @@ -5547,6 +6540,7 @@ fn expr_is_bytes_inner<'i>( } } +#[allow(clippy::too_many_arguments)] fn compile_length_expr<'i>( expr: &Expr<'i>, env: &HashMap>, @@ -5558,6 +6552,7 @@ fn compile_length_expr<'i>( stack_depth: &mut i64, script_size: Option, contract_constants: &HashMap>, + tracking: &mut ExprTrackingState, ) -> Result<(), CompilerError> { if let ExprKind::Identifier(name) = &expr.kind { if let Some(type_name) = types.get(name) { @@ -5567,7 +6562,7 @@ fn compile_length_expr<'i>( return Ok(()); } if let Some(element_size) = array_element_size(type_name) { - compile_expr( + compile_expr_with_tracking( expr, env, stack_bindings, @@ -5578,6 +6573,7 @@ fn compile_length_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpSize)?; builder.add_op(OpSwap)?; @@ -5595,7 +6591,19 @@ fn compile_length_expr<'i>( *stack_depth += 1; return Ok(()); } - compile_expr(expr, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + expr, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; builder.add_op(OpSize)?; builder.add_op(OpSwap)?; builder.add_op(OpDrop)?; @@ -5612,6 +6620,7 @@ fn compile_call_expr<'i>( stack_depth: &mut i64, script_size: Option, contract_constants: &HashMap>, + tracking: &mut ExprTrackingState, ) -> Result<(), CompilerError> { match name { "OpSha256" => compile_opcode_call( @@ -5626,12 +6635,13 @@ fn compile_call_expr<'i>( OpSHA256, script_size, contract_constants, + tracking, ), "sha256" => { if args.len() != 1 { return Err(CompilerError::Unsupported("sha256() expects a single argument".to_string())); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -5642,6 +6652,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpSHA256)?; Ok(()) @@ -5658,6 +6669,7 @@ fn compile_call_expr<'i>( OpTxSubnetId, script_size, contract_constants, + tracking, ), "OpTxGas" => compile_opcode_call( name, @@ -5671,6 +6683,7 @@ fn compile_call_expr<'i>( OpTxGas, script_size, contract_constants, + tracking, ), "OpTxPayloadLen" => compile_opcode_call( name, @@ -5684,6 +6697,7 @@ fn compile_call_expr<'i>( OpTxPayloadLen, script_size, contract_constants, + tracking, ), "OpTxPayloadSubstr" => compile_opcode_call( name, @@ -5697,6 +6711,7 @@ fn compile_call_expr<'i>( OpTxPayloadSubstr, script_size, contract_constants, + tracking, ), "OpOutpointTxId" => compile_opcode_call( name, @@ -5710,6 +6725,7 @@ fn compile_call_expr<'i>( OpOutpointTxId, script_size, contract_constants, + tracking, ), "OpOutpointIndex" => compile_opcode_call( name, @@ -5723,6 +6739,7 @@ fn compile_call_expr<'i>( OpOutpointIndex, script_size, contract_constants, + tracking, ), "OpTxInputScriptSigLen" => compile_opcode_call( name, @@ -5736,6 +6753,7 @@ fn compile_call_expr<'i>( OpTxInputScriptSigLen, script_size, contract_constants, + tracking, ), "OpTxInputScriptSigSubstr" => compile_opcode_call( name, @@ -5749,6 +6767,7 @@ fn compile_call_expr<'i>( OpTxInputScriptSigSubstr, script_size, contract_constants, + tracking, ), "OpTxInputSeq" => compile_opcode_call( name, @@ -5762,6 +6781,7 @@ fn compile_call_expr<'i>( OpTxInputSeq, script_size, contract_constants, + tracking, ), "OpTxInputIsCoinbase" => compile_opcode_call( name, @@ -5775,6 +6795,7 @@ fn compile_call_expr<'i>( OpTxInputIsCoinbase, script_size, contract_constants, + tracking, ), "OpTxInputSpkLen" => compile_opcode_call( name, @@ -5788,6 +6809,7 @@ fn compile_call_expr<'i>( OpTxInputSpkLen, script_size, contract_constants, + tracking, ), "OpTxInputSpkSubstr" => compile_opcode_call( name, @@ -5801,6 +6823,7 @@ fn compile_call_expr<'i>( OpTxInputSpkSubstr, script_size, contract_constants, + tracking, ), "OpTxOutputSpkLen" => compile_opcode_call( name, @@ -5814,6 +6837,7 @@ fn compile_call_expr<'i>( OpTxOutputSpkLen, script_size, contract_constants, + tracking, ), "OpTxOutputSpkSubstr" => compile_opcode_call( name, @@ -5827,6 +6851,7 @@ fn compile_call_expr<'i>( OpTxOutputSpkSubstr, script_size, contract_constants, + tracking, ), "OpAuthOutputCount" => compile_opcode_call( name, @@ -5840,6 +6865,7 @@ fn compile_call_expr<'i>( OpAuthOutputCount, script_size, contract_constants, + tracking, ), "OpAuthOutputIdx" => compile_opcode_call( name, @@ -5853,6 +6879,7 @@ fn compile_call_expr<'i>( OpAuthOutputIdx, script_size, contract_constants, + tracking, ), "OpInputCovenantId" => compile_opcode_call( name, @@ -5866,6 +6893,7 @@ fn compile_call_expr<'i>( OpInputCovenantId, script_size, contract_constants, + tracking, ), "OpCovInputCount" => compile_opcode_call( name, @@ -5879,6 +6907,7 @@ fn compile_call_expr<'i>( OpCovInputCount, script_size, contract_constants, + tracking, ), "OpCovInputIdx" => compile_opcode_call( name, @@ -5892,6 +6921,7 @@ fn compile_call_expr<'i>( OpCovInputIdx, script_size, contract_constants, + tracking, ), "OpCovOutputCount" => compile_opcode_call( name, @@ -5905,6 +6935,7 @@ fn compile_call_expr<'i>( OpCovOutputCount, script_size, contract_constants, + tracking, ), "OpCovOutputIdx" => compile_opcode_call( name, @@ -5918,6 +6949,7 @@ fn compile_call_expr<'i>( OpCovOutputIdx, script_size, contract_constants, + tracking, ), "OpNum2Bin" => compile_opcode_call( name, @@ -5931,6 +6963,7 @@ fn compile_call_expr<'i>( OpNum2Bin, script_size, contract_constants, + tracking, ), "OpBin2Num" => compile_opcode_call( name, @@ -5944,6 +6977,7 @@ fn compile_call_expr<'i>( OpBin2Num, script_size, contract_constants, + tracking, ), "OpChainblockSeqCommit" => compile_opcode_call( name, @@ -5957,13 +6991,14 @@ fn compile_call_expr<'i>( OpChainblockSeqCommit, script_size, contract_constants, + tracking, ), "bytes" => { if args.is_empty() || args.len() > 2 { return Err(CompilerError::Unsupported("bytes() expects one or two arguments".to_string())); } if args.len() == 2 { - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -5974,8 +7009,9 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; - compile_expr( + compile_expr_with_tracking( &args[1], scope.env, scope.stack_bindings, @@ -5986,6 +7022,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpNum2Bin)?; *stack_depth -= 1; @@ -6006,7 +7043,7 @@ fn compile_call_expr<'i>( } } if expr_is_bytes(&args[0], scope.env, scope.types) { - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6017,10 +7054,11 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; return Ok(()); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6031,6 +7069,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_i64(8)?; *stack_depth += 1; @@ -6040,7 +7079,7 @@ fn compile_call_expr<'i>( } _ => { if expr_is_bytes(&args[0], scope.env, scope.types) { - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6051,10 +7090,11 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; Ok(()) } else { - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6065,6 +7105,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_i64(8)?; *stack_depth += 1; @@ -6090,13 +7131,14 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, ) } "int" => { if args.len() != 1 { return Err(CompilerError::Unsupported("int() expects a single argument".to_string())); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6107,6 +7149,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; Ok(()) } @@ -6114,7 +7157,7 @@ fn compile_call_expr<'i>( if args.len() != 1 { return Err(CompilerError::Unsupported(format!("{name}() expects a single argument"))); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6125,6 +7168,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; Ok(()) } @@ -6135,7 +7179,7 @@ fn compile_call_expr<'i>( if args.len() != 1 && args.len() != 2 { return Err(CompilerError::Unsupported(format!("{name}() expects 1 or 2 arguments"))); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6146,10 +7190,11 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; if args.len() == 2 { // byte[](value, size) - OpNum2Bin with size parameter - compile_expr( + compile_expr_with_tracking( &args[1], scope.env, scope.stack_bindings, @@ -6160,6 +7205,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; *stack_depth += 1; builder.add_op(OpNum2Bin)?; @@ -6172,7 +7218,7 @@ fn compile_call_expr<'i>( if args.len() != 1 { return Err(CompilerError::Unsupported(format!("{name}() expects a single argument"))); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6183,6 +7229,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_i64(size)?; *stack_depth += 1; @@ -6195,7 +7242,7 @@ fn compile_call_expr<'i>( if args.len() != 1 { return Err(CompilerError::Unsupported("blake2b() expects a single argument".to_string())); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6206,6 +7253,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpBlake2b)?; Ok(()) @@ -6214,7 +7262,7 @@ fn compile_call_expr<'i>( if args.len() != 2 { return Err(CompilerError::Unsupported("checkSig() expects 2 arguments".to_string())); } - compile_expr( + compile_expr_with_tracking( &args[0], scope.env, scope.stack_bindings, @@ -6225,8 +7273,9 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; - compile_expr( + compile_expr_with_tracking( &args[1], scope.env, scope.stack_bindings, @@ -6237,6 +7286,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; builder.add_op(OpCheckSig)?; *stack_depth -= 1; @@ -6245,7 +7295,7 @@ fn compile_call_expr<'i>( "checkDataSig" => { // TODO: Remove this stub for arg in args { - compile_expr( + compile_expr_with_tracking( arg, scope.env, scope.stack_bindings, @@ -6256,6 +7306,7 @@ fn compile_call_expr<'i>( stack_depth, script_size, contract_constants, + tracking, )?; } for _ in 0..args.len() { @@ -6283,12 +7334,13 @@ fn compile_opcode_call<'i>( opcode: u8, script_size: Option, contract_constants: &HashMap>, + tracking: &mut ExprTrackingState, ) -> Result<(), CompilerError> { if args.len() != expected_args { return Err(CompilerError::Unsupported(format!("{name}() expects {expected_args} argument(s)"))); } for arg in args { - compile_expr( + compile_expr_with_tracking( arg, scope.env, scope.stack_bindings, @@ -6299,6 +7351,7 @@ fn compile_opcode_call<'i>( stack_depth, script_size, contract_constants, + tracking, )?; } builder.add_op(opcode)?; @@ -6306,6 +7359,7 @@ fn compile_opcode_call<'i>( Ok(()) } +#[allow(clippy::too_many_arguments)] fn compile_concat_operand<'i>( expr: &Expr<'i>, env: &HashMap>, @@ -6317,8 +7371,21 @@ fn compile_concat_operand<'i>( stack_depth: &mut i64, script_size: Option, contract_constants: &HashMap>, + tracking: &mut ExprTrackingState, ) -> Result<(), CompilerError> { - compile_expr(expr, env, stack_bindings, types, builder, options, visiting, stack_depth, script_size, contract_constants)?; + compile_expr_with_tracking( + expr, + env, + stack_bindings, + types, + builder, + options, + visiting, + stack_depth, + script_size, + contract_constants, + tracking, + )?; if !expr_is_bytes(expr, env, types) { builder.add_i64(1)?; *stack_depth += 1; diff --git a/silverscript-lang/tests/compiler_tests.rs b/silverscript-lang/tests/compiler_tests.rs index ccc5ae4b..4fdf4901 100644 --- a/silverscript-lang/tests/compiler_tests.rs +++ b/silverscript-lang/tests/compiler_tests.rs @@ -14,7 +14,7 @@ use kaspa_txscript::{ EngineCtx, EngineFlags, SeqCommitAccessor, TxScriptEngine, pay_to_address_script, pay_to_script_hash_script, pay_to_script_hash_signature_script, }; -use silverscript_lang::ast::{Expr, parse_contract_ast}; +use silverscript_lang::ast::{Expr, format_contract_ast, parse_contract_ast}; use silverscript_lang::compiler::{ CompileOptions, CompiledContract, CovenantDeclCallOptions, FunctionAbiEntry, FunctionInputAbi, compile_contract, compile_contract_ast, function_branch_index, struct_object, @@ -3250,9 +3250,10 @@ fn compiles_contract_fields_as_script_prolog() { .unwrap() .add_data(&[0x12, 0x34]) .unwrap() + // The final visible use of `x` rolls it directly into the comparison. .add_i64(1) .unwrap() - .add_op(OpPick) + .add_op(OpRoll) .unwrap() .add_i64(5) .unwrap() @@ -3260,8 +3261,7 @@ fn compiles_contract_fields_as_script_prolog() { .unwrap() .add_op(OpVerify) .unwrap() - .add_op(OpDrop) - .unwrap() + // Drop the remaining `y` field and leave success. .add_op(OpDrop) .unwrap() .add_op(OpTrue) @@ -5548,16 +5548,21 @@ fn compile_time_length_for_fixed_size_int_array() { "#; let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); - // Expected script for compile-time length: - // The nums.length should be replaced with a compile-time constant 5 - // require(nums.length == 5) becomes: <5> <5> OP_NUMEQUALVERIFY, then OP_TRUE for entrypoint return - let expected_script = vec![ - 0x55, // OP_5 (push 5 for nums.length) - 0x55, // OP_5 (push 5 for comparison) - 0x9c, // OP_NUMEQUALVERIFY (combined OP_NUMEQUAL + OP_VERIFY) - 0x69, // OP_VERIFY - 0x51, // OP_TRUE (entrypoint return value) - ]; + let expected_script = ScriptBuilder::new() + // `nums.length` is folded to the compile-time constant `5`. + .add_i64(5) + .unwrap() + // Compare it against the literal `5`. + .add_i64(5) + .unwrap() + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Entrypoints leave the success sentinel. + .add_op(OpTrue) + .unwrap() + .drain(); assert_eq!( compiled.script, expected_script, @@ -5578,16 +5583,21 @@ fn compile_time_length_for_fixed_size_byte_array() { "#; let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); - // Expected script for compile-time length: - // data.length should be replaced with a compile-time constant 3 - // require(data.length == 3) becomes: <3> <3> OP_NUMEQUALVERIFY, then OP_TRUE for entrypoint return - let expected_script = vec![ - 0x53, // OP_3 (push 3 for data.length) - 0x53, // OP_3 (push 3 for comparison) - 0x9c, // OP_NUMEQUALVERIFY (combined OP_NUMEQUAL + OP_VERIFY) - 0x69, // OP_VERIFY - 0x51, // OP_TRUE (entrypoint return value) - ]; + let expected_script = ScriptBuilder::new() + // `data.length` is folded to the compile-time constant `3`. + .add_i64(3) + .unwrap() + // Compare it against the literal `3`. + .add_i64(3) + .unwrap() + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Entrypoints leave the success sentinel. + .add_op(OpTrue) + .unwrap() + .drain(); assert_eq!( compiled.script, expected_script, @@ -5610,20 +5620,29 @@ fn compile_time_length_for_inferred_array_sizes() { "#; let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); - // Both lengths should be compile-time constants (no OP_SIZE path): - // require(data.length == 4) -> OP_4 OP_4 OP_NUMEQUALVERIFY - // require(nums.length == 3) -> OP_3 OP_3 OP_NUMEQUALVERIFY - let expected_script = vec![ - 0x54, // OP_4 (data.length) - 0x54, // OP_4 - 0x9c, // OP_NUMEQUALVERIFY - 0x69, // OP_VERIFY - 0x53, // OP_3 (nums.length) - 0x53, // OP_3 - 0x9c, // OP_NUMEQUALVERIFY - 0x69, // OP_VERIFY - 0x51, // OP_TRUE - ]; + let expected_script = ScriptBuilder::new() + // `data.length` is inferred as `4`. + .add_i64(4) + .unwrap() + .add_i64(4) + .unwrap() + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // `nums.length` is inferred as `3`. + .add_i64(3) + .unwrap() + .add_i64(3) + .unwrap() + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Entrypoints leave the success sentinel. + .add_op(OpTrue) + .unwrap() + .drain(); assert_eq!( compiled.script, expected_script, @@ -5719,17 +5738,20 @@ fn compile_time_length_with_constant_size() { "#; let compiled = compile_contract(source, &[], CompileOptions::default()).expect("compile succeeds"); - // Expected script for compile-time length with constant size: - // nums.length should be replaced with compile-time constant 5 (from SIZE) - // SIZE constant should also be replaced with 5 - // require(nums.length == SIZE) becomes: <5> <5> OP_NUMEQUALVERIFY - let expected_script = vec![ - 0x55, // OP_5 (push 5 for nums.length) - 0x55, // OP_5 (push 5 for SIZE constant) - 0x9c, // OP_NUMEQUALVERIFY - 0x69, // OP_VERIFY - 0x51, // OP_TRUE (entrypoint return value) - ]; + let expected_script = ScriptBuilder::new() + // Both `nums.length` and `SIZE` fold to the same compile-time constant. + .add_i64(5) + .unwrap() + .add_i64(5) + .unwrap() + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Entrypoints leave the success sentinel. + .add_op(OpTrue) + .unwrap() + .drain(); assert_eq!( compiled.script, expected_script, @@ -6194,197 +6216,953 @@ fn local_alias_reassignment_from_alias_passes_for_x_5() { } #[test] -fn local_bool_expression_is_stored_once_and_reused() { +fn reassignment_pushes_new_value_and_rebinds_stack_local() { let source = r#" - contract BoolRepeat() { - entrypoint function main(int x) { - bool y = x + 1 > 1; - require(y); - require(y == true); + contract ReassignPush() { + entrypoint function main(int start) { + int x = start + 1; + x = x + 1; + require(x == 4); } } "#; - let compiled = compile_contract(source, &[], CompileOptions::default()).expect("bool local should compile"); - - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), - 1, - "x + 1 should be computed once for the stored bool expression" - ); + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("reassignment should compile"); + let expected = ScriptBuilder::new() + // Initialize `x` from the `start` argument. + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + // Rebind `x` by incrementing the just-initialized stack local again. + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + // Verify the rebound top-of-stack value. + .add_i64(4) + .unwrap() + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Drop the shadowed pre-reassignment stack binding. + .add_op(OpDrop) + .unwrap() + .add_op(OpTrue) + .unwrap() + .drain(); + assert_eq!(compiled.script, expected); - let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(5)]).expect("sigscript builds"); + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(2)]).expect("sigscript builds"); let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); - assert!(result_ok.is_ok(), "stored bool local should execute successfully: {}", result_ok.unwrap_err()); + assert!(result_ok.is_ok(), "reassigned stack local should execute successfully: {}", result_ok.unwrap_err()); - let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(0)]).expect("sigscript builds"); + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(1)]).expect("sigscript builds"); let result_err = run_script_with_sigscript(compiled.script, sigscript_err); - assert!(result_err.is_err(), "stored bool local should still enforce the false branch"); + assert!(result_err.is_err(), "reassigned stack local should still enforce the updated value"); } #[test] -fn local_nested_expression_is_stored_once_and_reused() { +fn last_use_of_reassigned_param_rolls_instead_of_picking() { let source = r#" - contract NestedRepeat() { + contract ConsumeReassignedParam() { entrypoint function main(int x) { - int y = (x + 1) * (x + 2); - require(y > 10); - require(y < 100); + x = x + 1; + require(x > 0); } } "#; - let compiled = compile_contract(source, &[], CompileOptions::default()).expect("nested local should compile"); - - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), - 2, - "the nested local expression should compute each addition once before storing the result" - ); - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpMul).count(), - 1, - "the nested local expression should multiply once before storing the result" - ); - - let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(5)]).expect("sigscript builds"); - let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); - assert!(result_ok.is_ok(), "stored nested local should execute successfully: {}", result_ok.unwrap_err()); - - let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(10)]).expect("sigscript builds"); - let result_err = run_script_with_sigscript(compiled.script, sigscript_err); - assert!(result_err.is_err(), "stored nested local should still enforce the second require"); + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("param reassignment should compile"); + let expected = ScriptBuilder::new() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + // The final visible use of rebound `x` rolls it into the comparison. + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpGreaterThan) + .unwrap() + .add_op(OpVerify) + .unwrap() + .add_op(OpTrue) + .unwrap() + .drain(); + assert_eq!(compiled.script, expected); } #[test] -fn inline_nested_argument_expression_is_stored_once_and_reused() { +fn partially_reassigned_struct_field_rolls_last_use_without_copying_unchanged_fields() { let source = r#" - contract InlineCallRepeat() { - function f(int y) { - require(y > 10); - require(y < 100); + contract ConsumePartialStructField() { + struct S { + int a; + int b; } entrypoint function main(int x) { - f((x + 1) * (x + 2)); + S s = {a: x + 1, b: x * x}; + s = {a: s.a + 1, b: s.b}; + require(s.a > 0); + require(s.b > 0); } } "#; - let compiled = compile_contract(source, &[], CompileOptions::default()).expect("inline nested arg should compile"); - + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("partial struct reassignment should compile"); + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpMul).count(), + 1, + "the unchanged field should keep using its original expression instead of being copied into a new stack slot" + ); assert_eq!( compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), 2, - "the inline nested argument should compute each addition once and reuse the stored result" + "only the initial `s.a = x + 1` and the reassigned `s.a = s.a + 1` should emit additions" ); - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpMul).count(), - 1, - "the inline nested argument should multiply once and reuse the stored result" + assert!( + compiled.script.iter().copied().filter(|op| *op == OpRoll).count() >= 2, + "the stack-backed struct leaves should be rebound with rolls instead of rebuilding the whole struct" ); - let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(5)]).expect("sigscript builds"); + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(2)]).expect("sigscript builds"); let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); - assert!(result_ok.is_ok(), "stored inline nested argument should execute successfully: {}", result_ok.unwrap_err()); + assert!(result_ok.is_ok(), "partial struct reassignment should execute successfully: {}", result_ok.unwrap_err()); - let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(10)]).expect("sigscript builds"); + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(0)]).expect("sigscript builds"); let result_err = run_script_with_sigscript(compiled.script, sigscript_err); - assert!(result_err.is_err(), "stored inline nested argument should still enforce the second require"); + assert!(result_err.is_err(), "partial struct reassignment should still enforce the updated field checks"); } #[test] -fn function_call_assignment_result_is_stored_once_and_reused() { +fn if_without_else_reassignment_gets_normalized() { let source = r#" - contract CallAssignRepeat() { - function g(int x) : (int) { - require(x > 0); - return(x - 17); - } - - function f(int x) : (int) { - require(x > 17); - (int base) = g(x); - int shifted = base + 2; - return(shifted * 2); - } - - entrypoint function main(int x) { - (int y) = f(x); - require(y > 1); - require(y < 10); + contract MissingElse() { + entrypoint function main(int flag) { + int x = 1; + if (flag > 0) { + x = x + 1; + } + require(x == 2 - (flag <= 0)); } } "#; - let compiled = compile_contract(source, &[], CompileOptions::default()).expect("function-call assignment should compile"); - - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpSub).count(), - 1, - "the nested g(x) return calculation should be computed once and the assigned local reused" - ); - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpMul).count(), - 1, - "the extra arithmetic in f(x) should be computed once and the assigned local reused" - ); - - let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(19)]).expect("sigscript builds"); - let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); - assert!(result_ok.is_ok(), "stored function-call assignment result should execute successfully: {}", result_ok.unwrap_err()); + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("if without else should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("else"), "normalized AST should include an else branch: {normalized}"); + assert!(normalized.contains("__if_x_"), "normalized AST should introduce an if-local binding for x: {normalized}"); - let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(29)]).expect("sigscript builds"); - let result_err = run_script_with_sigscript(compiled.script, sigscript_err); - assert!(result_err.is_err(), "stored function-call assignment result should still enforce the second require"); + assert!(compiled.script.contains(&OpElse), "compiled script should include an explicit else branch"); } #[test] -fn struct_return_field_is_stored_once_and_reused() { +fn normalized_if_branches_roll_their_last_use_of_reassigned_param() { let source = r#" - contract StructFieldRepeat() { - struct S { - int a; - int b; - } - - function f(int x) : (S) { - return({ - a: x + 1, - b: x * x, - }); - } - + contract MissingElseConsume() { entrypoint function main(int x) { - (S s) = f(x); - require(s.a < 10); - require(s.b < 20); - require(s.a > 1); - require(s.b > 2); + if (true) { + x = x + 1; + } + require(x > 0); } } "#; - let compiled = compile_contract(source, &[], CompileOptions::default()).expect("struct-return local should compile"); - - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), - 1, - "s.a should be computed once and reused across both require statements" - ); - assert_eq!( - compiled.script.iter().copied().filter(|op| *op == OpMul).count(), - 1, - "s.b should be computed once and reused across both require statements" - ); - - let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(3)]).expect("sigscript builds"); - let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); - assert!(result_ok.is_ok(), "stored struct fields should execute successfully: {}", result_ok.unwrap_err()); - - let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(10)]).expect("sigscript builds"); - let result_err = run_script_with_sigscript(compiled.script, sigscript_err); - assert!(result_err.is_err(), "stored struct fields should still enforce the require conditions"); + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("normalized if should compile"); + let expected = ScriptBuilder::new() + .add_i64(1) + .unwrap() + .add_op(OpIf) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpElse) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpEndIf) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpGreaterThan) + .unwrap() + .add_op(OpVerify) + .unwrap() + .add_op(OpTrue) + .unwrap() + .drain(); + assert_eq!(compiled.script, expected); +} + +#[test] +fn if_normalization_mirrors_missing_reassigned_vars_across_branches() { + let source = r#" + contract BranchMirror() { + entrypoint function main(int flag, int x, int y, int expected_x, int expected_y) { + if (flag > 0) { + x = x + 1; + } else { + y = y * 2; + } + require(x == expected_x); + require(y == expected_y); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("branch mirroring should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("__if_x_"), "normalized AST should materialize x in both branches: {normalized}"); + assert!(normalized.contains("__if_y_"), "normalized AST should materialize y in both branches: {normalized}"); + + assert!(compiled.script.contains(&OpElse), "compiled script should keep both mirrored branches"); + let expected = ScriptBuilder::new() + // Pick `flag` from the stack. + .add_i64(4) + .unwrap() + .add_op(OpPick) + .unwrap() + // Compare `flag > 0`. + .add_i64(0) + .unwrap() + .add_op(OpGreaterThan) + .unwrap() + // Branch on the result of `flag > 0`. + .add_op(OpIf) + .unwrap() + // Pick `x` and push `x + 1`. + .add_i64(3) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + // Mirror the branch outputs so both branches leave the same stack shape. + .add_i64(3) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpPick) + .unwrap() + // Drop the original `flag` and stale `x` inputs after mirroring them. + .add_i64(3) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_i64(2) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + // Start the `else` branch. + .add_op(OpElse) + .unwrap() + // Pick `y` and push `y * 2`. + .add_i64(2) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(2) + .unwrap() + .add_op(OpMul) + .unwrap() + // Mirror the branch outputs so the else branch matches the then-branch shape. + .add_i64(3) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(2) + .unwrap() + .add_op(OpPick) + .unwrap() + // Drop the original `flag` and stale `y` inputs after mirroring them. + .add_i64(3) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_i64(2) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + // Finish the branch. + .add_op(OpEndIf) + .unwrap() + // The final visible uses of mirrored `x` and `y` roll directly into their comparisons. + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + // Pick `expected_x`. + .add_i64(3) + .unwrap() + .add_op(OpRoll) + .unwrap() + // Require `x == expected_x`. + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Pick the mirrored `y`. + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + // Pick `expected_y`. + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + // Require `y == expected_y`. + .add_op(OpNumEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Drop the remaining `flag` argument and leave success. + .add_op(OpDrop) + .unwrap() + // Leave the success sentinel. + .add_op(OpTrue) + .unwrap() + .drain(); + assert_eq!(compiled.script, expected); +} + +#[test] +fn if_reassignment_pushes_shared_outputs_in_branch_order() { + let source = r#" + contract BranchReassignOrder() { + entrypoint function main(int flag, int expected_x, int expected_y) { + int x = 1; + int y = 2; + if (flag > 0) { + int z = 5; + x = x + z; + x = x + 1; + y = y * 2; + } else { + y = y * 3; + x = x + y; + } + require(x == expected_x && y == expected_y); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("branch reassignment ordering should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("__if_x_"), "normalized AST should route x through an if-local binding: {normalized}"); + assert!(normalized.contains("__if_y_"), "normalized AST should route y through an if-local binding: {normalized}"); + + assert!(compiled.script.contains(&OpElse), "compiled script should encode both branch outputs"); +} + +// Regresses branch-local cleanup when an `if` branch rebinds an outer stack local: +// the branch temporary must be removed from its real stack depth without leaving a +// hidden shadow binding behind after the branch merges. +#[test] +fn if_branch_reassignment_drops_hidden_shadow_bindings() { + let source = r#" + contract BranchShadowCleanup() { + entrypoint function main(int flag, int a, int b, int expected) { + int d = a + b; + d = d - a; + if (flag > 0) { + int c = d + b; + d = a + c; + } else { + d = d + a; + } + require(d == expected); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("if branch reassignment should compile"); + + let sigscript_then = + compiled.build_sig_script("main", vec![Expr::int(1), Expr::int(1), Expr::int(1), Expr::int(3)]).expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script.clone(), sigscript_then); + assert!(result_then.is_ok(), "then-branch reassignment should leave a clean stack: {}", result_then.unwrap_err()); + + let sigscript_else = + compiled.build_sig_script("main", vec![Expr::int(0), Expr::int(1), Expr::int(1), Expr::int(2)]).expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script, sigscript_else); + assert!(result_else.is_ok(), "else-branch reassignment should leave a clean stack: {}", result_else.unwrap_err()); +} + +#[test] +fn struct_if_without_else_reassignment_gets_normalized() { + let source = r#" + contract MissingElseStruct() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag) { + S s = {a: 1, b: 2}; + if (flag > 0) { + s = {a: s.a + 1, b: s.b + 1}; + } + require(s.a > 0); + require(s.b > 0); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("struct if without else should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("else"), "normalized AST should include an else branch: {normalized}"); + assert!(normalized.contains("__if_s_"), "normalized AST should mirror the struct binding itself: {normalized}"); + assert!(compiled.script.contains(&OpElse), "compiled script should include an explicit else branch"); +} + +#[test] +fn partial_struct_if_without_else_reassignment_gets_normalized() { + let source = r#" + contract MissingElsePartialStruct() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag, int expected_a, int expected_b) { + S s = {a: 1, b: 2}; + if (flag > 0) { + s = {a: s.a + 1, b: s.b}; + } + require(s.a == expected_a); + require(s.b == expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("partial struct if without else should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("else"), "normalized AST should include an else branch: {normalized}"); + assert!(normalized.contains("__if_s_"), "normalized AST should mirror the struct binding itself: {normalized}"); + assert!(compiled.script.contains(&OpElse), "compiled script should include an explicit else branch"); + + let sigscript_else = compiled.build_sig_script("main", vec![Expr::int(0), Expr::int(1), Expr::int(2)]).expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script.clone(), sigscript_else); + assert!(result_else.is_ok(), "flag=0 should preserve the untouched field through normalization: {}", result_else.unwrap_err()); + + let sigscript_then = compiled.build_sig_script("main", vec![Expr::int(1), Expr::int(2), Expr::int(2)]).expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script, sigscript_then); + assert!(result_then.is_ok(), "flag=1 should update only the changed field through normalization: {}", result_then.unwrap_err()); +} + +#[test] +fn nested_struct_if_reassignment_materializes_missing_synthetic_fields() { + let source = r#" + contract NestedStructIf() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag, int expected_a, int expected_b) { + S s = {a: 1, b: 2}; + if (flag > 0) { + if (flag > 1) { + s = {a: s.a + 1, b: s.b + 1}; + } + } + require(s.a == expected_a); + require(s.b == expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("nested struct if reassignment should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("__if_s_"), "nested normalization should synthesize struct temporaries: {normalized}"); + + let sigscript_else = compiled.build_sig_script("main", vec![Expr::int(0), Expr::int(1), Expr::int(2)]).expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script.clone(), sigscript_else); + assert!(result_else.is_ok(), "flag=0 should preserve the original struct fields: {}", result_else.unwrap_err()); + + let sigscript_then = compiled.build_sig_script("main", vec![Expr::int(2), Expr::int(2), Expr::int(3)]).expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script, sigscript_then); + assert!(result_then.is_ok(), "flag=2 should update both struct fields through the nested branch: {}", result_then.unwrap_err()); +} + +#[test] +fn nested_partial_struct_if_reassignment_materializes_missing_synthetic_fields() { + let source = r#" + contract NestedPartialStructIf() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag, int expected_a, int expected_b) { + S s = {a: 1, b: 2}; + if (flag > 0) { + if (flag > 1) { + s = {a: s.a + 1, b: s.b}; + } + } + require(s.a == expected_a); + require(s.b == expected_b); + } + } + "#; + + let compiled = + compile_contract(source, &[], CompileOptions::default()).expect("nested partial struct if reassignment should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("__if_s_"), "nested normalization should synthesize struct temporaries: {normalized}"); + + let sigscript_else = compiled.build_sig_script("main", vec![Expr::int(0), Expr::int(1), Expr::int(2)]).expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script.clone(), sigscript_else); + assert!(result_else.is_ok(), "flag=0 should preserve both struct fields: {}", result_else.unwrap_err()); + + let sigscript_then = compiled.build_sig_script("main", vec![Expr::int(2), Expr::int(2), Expr::int(2)]).expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script, sigscript_then); + assert!( + result_then.is_ok(), + "flag=2 should update only the changed field through the nested branch: {}", + result_then.unwrap_err() + ); +} + +#[test] +fn struct_if_normalization_mirrors_missing_reassigned_fields_across_branches() { + let source = r#" + contract StructBranchMirror() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag, int expected_a, int expected_b) { + S s = {a: 2, b: 3}; + if (flag > 0) { + s = {a: s.a + 1, b: s.b}; + } else { + s = {a: s.a, b: s.b * 2}; + } + require(s.a == expected_a); + require(s.b == expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("struct branch mirroring should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("__if_s_"), "normalized AST should materialize the struct binding in both branches: {normalized}"); + assert!(compiled.script.contains(&OpElse), "compiled script should keep both mirrored branches"); + + let sigscript_then = compiled.build_sig_script("main", vec![Expr::int(1), Expr::int(3), Expr::int(3)]).expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script.clone(), sigscript_then); + assert!(result_then.is_ok(), "then-branch struct mirroring should execute successfully: {}", result_then.unwrap_err()); + + let sigscript_else = compiled.build_sig_script("main", vec![Expr::int(0), Expr::int(2), Expr::int(6)]).expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script, sigscript_else); + assert!(result_else.is_ok(), "else-branch struct mirroring should execute successfully: {}", result_else.unwrap_err()); +} + +#[test] +fn struct_if_reassignment_preserves_types_after_merge() { + let source = r#" + contract StructMergeTypes() { + struct S { + int a; + int b; + } + + function verify_pair(S value, int expected_a, int expected_b) { + require(value.a == expected_a); + require(value.b == expected_b); + } + + entrypoint function main(int flag, int expected_a, int expected_b) { + S s = {a: 2, b: 3}; + if (flag > 0) { + s = {a: s.a + 1, b: s.b + 1}; + } else { + s = {a: s.a + 2, b: s.b + 2}; + } + S t = s; + verify_pair(t, expected_a, expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("post-if struct type merge should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("S t = s;"), "merged struct type should still allow assignment after the if: {normalized}"); +} + +#[test] +fn partial_struct_if_reassignment_preserves_types_after_merge() { + let source = r#" + contract PartialStructMergeTypes() { + struct S { + int a; + int b; + } + + function verify_pair(S value, int expected_a, int expected_b) { + require(value.a == expected_a); + require(value.b == expected_b); + } + + entrypoint function main(int flag, int expected_a, int expected_b) { + S s = {a: 2, b: 3}; + if (flag > 0) { + s = {a: s.a + 1, b: s.b}; + } else { + s = {a: s.a, b: s.b + 2}; + } + S t = s; + verify_pair(t, expected_a, expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("post-if partial struct type merge should compile"); + let normalized = format_contract_ast(&compiled.ast); + assert!(normalized.contains("S t = s;"), "merged struct type should still allow assignment after the if: {normalized}"); +} + +#[test] +fn struct_if_branch_reassignment_drops_hidden_shadow_bindings() { + let source = r#" + contract StructBranchCleanup() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag, int x, int y, int expected_a, int expected_b) { + S s = {a: x, b: y}; + if (flag > 0) { + S t = {a: s.a + 1, b: s.b + 2}; + s = {a: t.a + y, b: t.b + x}; + } else { + S t = {a: s.a + x, b: s.b + y}; + s = {a: t.a + 1, b: t.b + 1}; + } + require(s.a == expected_a); + require(s.b == expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("struct branch cleanup should compile"); + + let sigscript_then = compiled + .build_sig_script("main", vec![Expr::int(1), Expr::int(2), Expr::int(3), Expr::int(6), Expr::int(7)]) + .expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script.clone(), sigscript_then); + assert!(result_then.is_ok(), "then-branch struct cleanup should leave a clean stack: {}", result_then.unwrap_err()); + + let sigscript_else = compiled + .build_sig_script("main", vec![Expr::int(0), Expr::int(2), Expr::int(3), Expr::int(5), Expr::int(7)]) + .expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script, sigscript_else); + assert!(result_else.is_ok(), "else-branch struct cleanup should leave a clean stack: {}", result_else.unwrap_err()); +} + +#[test] +fn partial_struct_if_branch_reassignment_drops_hidden_shadow_bindings() { + let source = r#" + contract PartialStructBranchCleanup() { + struct S { + int a; + int b; + } + + entrypoint function main(int flag, int x, int y, int expected_a, int expected_b) { + S s = {a: x, b: y}; + if (flag > 0) { + S t = {a: s.a + 1, b: s.b}; + s = {a: t.a + y, b: s.b}; + } else { + S t = {a: s.a, b: s.b + y}; + s = {a: s.a, b: t.b + x}; + } + require(s.a == expected_a); + require(s.b == expected_b); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("partial struct branch cleanup should compile"); + + let sigscript_then = compiled + .build_sig_script("main", vec![Expr::int(1), Expr::int(2), Expr::int(3), Expr::int(6), Expr::int(3)]) + .expect("sigscript builds"); + let result_then = run_script_with_sigscript(compiled.script.clone(), sigscript_then); + assert!(result_then.is_ok(), "then-branch partial struct cleanup should leave a clean stack: {}", result_then.unwrap_err()); + + let sigscript_else = compiled + .build_sig_script("main", vec![Expr::int(0), Expr::int(2), Expr::int(3), Expr::int(2), Expr::int(8)]) + .expect("sigscript builds"); + let result_else = run_script_with_sigscript(compiled.script, sigscript_else); + assert!(result_else.is_ok(), "else-branch partial struct cleanup should leave a clean stack: {}", result_else.unwrap_err()); +} + +#[test] +fn local_bool_expression_is_stored_once_and_reused() { + let source = r#" + contract BoolRepeat() { + entrypoint function main(int x) { + bool y = x + 1 > 1; + require(y); + require(y == true); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("bool local should compile"); + + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), + 1, + "x + 1 should be computed once for the stored bool expression" + ); + + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(5)]).expect("sigscript builds"); + let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); + assert!(result_ok.is_ok(), "stored bool local should execute successfully: {}", result_ok.unwrap_err()); + + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(0)]).expect("sigscript builds"); + let result_err = run_script_with_sigscript(compiled.script, sigscript_err); + assert!(result_err.is_err(), "stored bool local should still enforce the false branch"); +} + +#[test] +fn local_nested_expression_is_stored_once_and_reused() { + let source = r#" + contract NestedRepeat() { + entrypoint function main(int x) { + int y = (x + 1) * (x + 2); + require(y > 10); + require(y < 100); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("nested local should compile"); + + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), + 2, + "the nested local expression should compute each addition once before storing the result" + ); + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpMul).count(), + 1, + "the nested local expression should multiply once before storing the result" + ); + + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(5)]).expect("sigscript builds"); + let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); + assert!(result_ok.is_ok(), "stored nested local should execute successfully: {}", result_ok.unwrap_err()); + + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(10)]).expect("sigscript builds"); + let result_err = run_script_with_sigscript(compiled.script, sigscript_err); + assert!(result_err.is_err(), "stored nested local should still enforce the second require"); +} + +#[test] +fn inline_nested_argument_expression_is_stored_once_and_reused() { + let source = r#" + contract InlineCallRepeat() { + function f(int y) { + require(y > 10); + require(y < 100); + } + + entrypoint function main(int x) { + f((x + 1) * (x + 2)); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("inline nested arg should compile"); + + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), + 2, + "the inline nested argument should compute each addition once and reuse the stored result" + ); + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpMul).count(), + 1, + "the inline nested argument should multiply once and reuse the stored result" + ); + + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(5)]).expect("sigscript builds"); + let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); + assert!(result_ok.is_ok(), "stored inline nested argument should execute successfully: {}", result_ok.unwrap_err()); + + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(10)]).expect("sigscript builds"); + let result_err = run_script_with_sigscript(compiled.script, sigscript_err); + assert!(result_err.is_err(), "stored inline nested argument should still enforce the second require"); +} + +#[test] +fn function_call_assignment_result_is_stored_once_and_reused() { + let source = r#" + contract CallAssignRepeat() { + function g(int x) : (int) { + require(x > 0); + return(x - 17); + } + + function f(int x) : (int) { + require(x > 17); + (int base) = g(x); + int shifted = base + 2; + return(shifted * 2); + } + + entrypoint function main(int x) { + (int y) = f(x); + require(y > 1); + require(y < 10); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("function-call assignment should compile"); + + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpSub).count(), + 1, + "the nested g(x) return calculation should be computed once and the assigned local reused" + ); + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpMul).count(), + 1, + "the extra arithmetic in f(x) should be computed once and the assigned local reused" + ); + + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(19)]).expect("sigscript builds"); + let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); + assert!(result_ok.is_ok(), "stored function-call assignment result should execute successfully: {}", result_ok.unwrap_err()); + + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(29)]).expect("sigscript builds"); + let result_err = run_script_with_sigscript(compiled.script, sigscript_err); + assert!(result_err.is_err(), "stored function-call assignment result should still enforce the second require"); +} + +#[test] +fn struct_return_field_is_stored_once_and_reused() { + let source = r#" + contract StructFieldRepeat() { + struct S { + int a; + int b; + } + + function f(int x) : (S) { + return({ + a: x + 1, + b: x * x, + }); + } + + entrypoint function main(int x) { + (S s) = f(x); + require(s.a < 10); + require(s.b < 20); + require(s.a > 1); + require(s.b > 2); + } + } + "#; + + let compiled = compile_contract(source, &[], CompileOptions::default()).expect("struct-return local should compile"); + + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpAdd).count(), + 1, + "s.a should be computed once and reused across both require statements" + ); + assert_eq!( + compiled.script.iter().copied().filter(|op| *op == OpMul).count(), + 1, + "s.b should be computed once and reused across both require statements" + ); + + let sigscript_ok = compiled.build_sig_script("main", vec![Expr::int(3)]).expect("sigscript builds"); + let result_ok = run_script_with_sigscript(compiled.script.clone(), sigscript_ok); + assert!(result_ok.is_ok(), "stored struct fields should execute successfully: {}", result_ok.unwrap_err()); + + let sigscript_err = compiled.build_sig_script("main", vec![Expr::int(10)]).expect("sigscript builds"); + let result_err = run_script_with_sigscript(compiled.script, sigscript_err); + assert!(result_err.is_err(), "stored struct fields should still enforce the require conditions"); } #[test] @@ -6507,6 +7285,7 @@ fn compile_time_if_branch_stores_struct_fields_once_and_reuses_them() { .unwrap() .add_op(OpIf) .unwrap() + // Compute `s.a = x + 1`. .add_i64(0) .unwrap() .add_op(OpPick) @@ -6515,6 +7294,7 @@ fn compile_time_if_branch_stores_struct_fields_once_and_reuses_them() { .unwrap() .add_op(OpAdd) .unwrap() + // Compute `s.b = x * x`. .add_i64(1) .unwrap() .add_op(OpPick) @@ -6525,6 +7305,7 @@ fn compile_time_if_branch_stores_struct_fields_once_and_reuses_them() { .unwrap() .add_op(OpMul) .unwrap() + // Reuse `s.a` and `s.b` across both require statements. .add_i64(1) .unwrap() .add_op(OpPick) @@ -6565,6 +7346,11 @@ fn compile_time_if_branch_stores_struct_fields_once_and_reuses_them() { .unwrap() .add_op(OpVerify) .unwrap() + // Drop the cached struct fields from their actual stack positions. + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() .add_op(OpDrop) .unwrap() .add_op(OpDrop) @@ -6577,6 +7363,7 @@ fn compile_time_if_branch_stores_struct_fields_once_and_reuses_them() { .unwrap() .add_op(OpEndIf) .unwrap() + // Drop the original argument and leave success. .add_op(OpDrop) .unwrap() .add_op(OpTrue) @@ -6585,3 +7372,245 @@ fn compile_time_if_branch_stores_struct_fields_once_and_reuses_them() { assert_eq!(compiled.script, expected); } + +#[test] +fn conditional_counter_in_unrolled_loop_stays_linear() { + const SOURCE: &str = r#" +pragma silverscript ^0.1.0; + +contract CounterLoop(int BOUND) { + entrypoint function main() { + int count = 0; + // Keep this loop small so regressions fail fast (the previous exponential blow-up + // already manifested at single-digit iteration counts). + for (i, 0, BOUND, BOUND) { + if (true) { + count = count + 1; + } + } + require(count >= 0); + } +} +"#; + + let bounds = [4i64, 8i64, 12i64]; + let mut lens = Vec::new(); + for b in bounds { + let args = [Expr::int(b)]; + let compiled = compile_contract(SOURCE, &args, CompileOptions::default()).expect("compile succeeds"); + lens.push(compiled.script.len()); + } + + assert!(lens[0] < lens[1] && lens[1] < lens[2], "expected monotonic growth, got {lens:?}"); + let d1 = lens[1] - lens[0]; + let d2 = lens[2] - lens[1]; + + assert!(d2 <= d1 * 2, "unexpected superlinear growth: lens={lens:?} d1={d1} d2={d2}"); + + // Absolute cap: the old exponential behavior already blew past this by bound=8..12. + assert!(lens[2] < 5_000, "unexpected script size: lens={lens:?}"); +} + +#[test] +fn struct_conditional_counter_in_unrolled_loop_stays_linear() { + const SOURCE: &str = r#" +pragma silverscript ^0.1.0; + +contract StructCounterLoop(int BOUND) { + struct S { + int a; + int b; + } + + entrypoint function main() { + S s = {a: 0, b: 0}; + // Keep this loop small so regressions fail fast (the previous exponential blow-up + // already manifested at single-digit iteration counts). + for (i, 0, BOUND, BOUND) { + if (true) { + s = {a: s.a + 1, b: s.b + 1}; + } + } + require(s.a >= 0); + require(s.b >= 0); + } +} +"#; + + let bounds = [4i64, 8i64, 12i64]; + let mut lens = Vec::new(); + for b in bounds { + let args = [Expr::int(b)]; + let compiled = compile_contract(SOURCE, &args, CompileOptions::default()).expect("compile succeeds"); + lens.push(compiled.script.len()); + } + + assert!(lens[0] < lens[1] && lens[1] < lens[2], "expected monotonic growth, got {lens:?}"); + let d1 = lens[1] - lens[0]; + let d2 = lens[2] - lens[1]; + + assert!(d2 <= d1 * 2, "unexpected superlinear growth: lens={lens:?} d1={d1} d2={d2}"); + + // Absolute cap: the old exponential behavior already blew past this by bound=8..12. + assert!(lens[2] < 10_000, "unexpected script size: lens={lens:?}"); +} + +#[test] +fn conditional_counter_in_unrolled_loop_matches_expected_script_with_bound_3() { + let source = r#" + contract CounterLoop(int BOUND) { + entrypoint function main() { + int count = 0; + for (i, 0, BOUND, BOUND) { + if (true) { + count = count + 1; + } + } + require(count >= 0); + } + } + "#; + + let compiled = compile_contract(source, &[Expr::int(3)], CompileOptions::default()).expect("compile succeeds"); + + let expected_script = ScriptBuilder::new() + // Iteration 0 guard has folded to `true`. + .add_i64(1) + .unwrap() + .add_op(OpIf) + .unwrap() + // Increment `count`. + .add_i64(0) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + // Mirror the synthetic branch output and drop the stale pre-branch binding. + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpElse) + .unwrap() + // Missing `else` branch still mirrors the old `count`. + .add_i64(0) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpEndIf) + .unwrap() + // Iteration 1. + .add_i64(1) + .unwrap() + .add_op(OpIf) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpElse) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpEndIf) + .unwrap() + // Iteration 2. + .add_i64(1) + .unwrap() + .add_op(OpIf) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpAdd) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpElse) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpPick) + .unwrap() + .add_i64(1) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_op(OpDrop) + .unwrap() + .add_op(OpEndIf) + .unwrap() + // The final visible use of `count` rolls directly into the comparison. + .add_i64(0) + .unwrap() + .add_op(OpRoll) + .unwrap() + .add_i64(0) + .unwrap() + .add_op(OpGreaterThanOrEqual) + .unwrap() + .add_op(OpVerify) + .unwrap() + // Leave success. + .add_op(OpTrue) + .unwrap() + .drain(); + + assert_eq!(compiled.script, expected_script); +}