diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 2e759ad21..4233042db 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -154,6 +154,7 @@ def optimize_code( is_async: bool = False, n_candidates: int = 5, is_numerical_code: bool | None = None, + rerun_trace_id: str | None = None, ) -> list[OptimizedCandidate]: """Optimize the given code for performance by making a request to the Django endpoint. @@ -194,6 +195,7 @@ def optimize_code( "call_sequence": self.get_next_sequence(), "n_candidates": n_candidates, "is_numerical_code": is_numerical_code, + "rerun_trace_id": rerun_trace_id, } self.add_language_metadata(payload, language_version, module_system) @@ -234,6 +236,7 @@ def optimize_python_code_line_profiler( is_numerical_code: bool | None = None, language: str = "python", language_version: str | None = None, + rerun_trace_id: str | None = None, ) -> list[OptimizedCandidate]: """Optimize code for performance using line profiler results. @@ -272,6 +275,7 @@ def optimize_python_code_line_profiler( "codeflash_version": codeflash_version, "call_sequence": self.get_next_sequence(), "is_numerical_code": is_numerical_code, + "rerun_trace_id": rerun_trace_id, } try: @@ -318,7 +322,9 @@ def adaptive_optimize(self, request: AIServiceAdaptiveOptimizeRequest) -> Optimi self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response") return None - def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]: + def optimize_code_refinement( + self, request: list[AIServiceRefinerRequest], rerun_trace_id: str | None = None + ) -> list[OptimizedCandidate]: """Refine optimization candidates for improved performance. Supports Python, JavaScript, and TypeScript code refinement with optional @@ -349,6 +355,7 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li "call_sequence": self.get_next_sequence(), # Multi-language support "language": opt.language, + "rerun_trace_id": rerun_trace_id, } self.add_language_metadata(item, opt.language_version) @@ -375,7 +382,9 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li console.rule() return [] - def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None: + def code_repair( + self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None + ) -> OptimizedCandidate | None: console.rule() try: payload = { @@ -385,6 +394,7 @@ def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate "trace_id": request.trace_id, "test_diffs": request.test_diffs, "language": request.language, + "rerun_trace_id": rerun_trace_id, } response = self.make_ai_service_request("/code_repair", payload=payload, timeout=self.timeout) except (requests.exceptions.RequestException, TypeError) as e: @@ -607,6 +617,7 @@ def generate_regression_tests( language_version: str | None = None, module_system: str | None = None, is_numerical_code: bool | None = None, + rerun_trace_id: str | None = None, ) -> tuple[str, str, str, str | None] | None: """Generate regression tests for the given function by making a request to the Django endpoint. @@ -655,6 +666,7 @@ def generate_regression_tests( "is_numerical_code": is_numerical_code, "class_name": function_to_optimize.class_name, "qualified_name": function_to_optimize.qualified_name, + "rerun_trace_id": rerun_trace_id, } self.add_language_metadata(payload, language_version, module_system) @@ -913,6 +925,34 @@ def generate_workflow_steps( logger.debug("[aiservice.py:generate_workflow_steps] Could not parse error response") return None + def prescreen_functions( + self, functions: list[dict[str, str]], trace_id: str | None = None + ) -> dict[str, dict[str, Any]] | None: + """Pre-screen functions for optimization potential using LLM assessment. + + Args: + functions: List of dicts with keys: qualified_name, source_code, language + trace_id: Optional trace ID for logging + + Returns: + Dict mapping qualified_name to {score: int, optimizable: bool, reason: str}, + or None if the call fails. + + """ + payload: dict[str, Any] = {"functions": functions, "trace_id": trace_id} + try: + response = self.make_ai_service_request("/prescreen", payload=payload, timeout=30) + except requests.exceptions.RequestException as e: + logger.debug(f"Prescreening request failed: {e}") + return None + + if response.status_code == 200: + logger.info(f"loading|Pre-screening {len(functions)} function(s) for optimization potential") + result: dict[str, dict[str, Any]] = response.json().get("functions", {}) + return result + logger.debug(f"Prescreening returned status {response.status_code}") + return None + class LocalAiServiceClient(AiServiceClient): """Client for interacting with the local AI service.""" diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 73b15b0ad..b47496944 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -100,6 +100,8 @@ def process_pyproject_config(args: Namespace) -> Namespace: "disable_imports_sorting", "git_remote", "override_fixtures", + "max_functions", + "max_time", ] for key in supported_keys: if key in pyproject_config and ( @@ -489,5 +491,33 @@ def _build_parser() -> ArgumentParser: action="store_true", help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.", ) + parser.add_argument( + "--max-functions", + type=int, + default=None, + help="Maximum number of functions to optimize per run. After ranking, only the top N functions are optimized. " + "Useful for controlling CI runtime.", + ) + parser.add_argument( + "--max-time", + type=int, + default=None, + help="Maximum total optimization time in minutes. Stops optimizing new functions after this budget is exhausted. " + "Functions already in progress will complete.", + ) + parser.add_argument( + "--prescreening", + action="store_true", + default=False, + help="Enable LLM-based pre-screening of functions before optimization. " + "Uses a fast LLM call to assess whether each function has meaningful optimization potential. " + "Filters out functions unlikely to benefit from optimization.", + ) + parser.add_argument( + "--no-prescreening", + action="store_true", + default=False, + help="Disable LLM pre-screening even when it would be auto-enabled (e.g., in CI with --all).", + ) return parser diff --git a/codeflash/discovery/diff_classifier.py b/codeflash/discovery/diff_classifier.py new file mode 100644 index 000000000..819d31762 --- /dev/null +++ b/codeflash/discovery/diff_classifier.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum +from io import StringIO +from pathlib import Path +from typing import TYPE_CHECKING + +import git +from unidiff import PatchSet + +from codeflash.cli_cmds.console import logger + +if TYPE_CHECKING: + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + + +class DiffCategory(str, Enum): + COSMETIC = "cosmetic" + TRIVIAL = "trivial" + MEANINGFUL = "meaningful" + MAJOR = "major" + + +@dataclass(frozen=True) +class FunctionDiffInfo: + category: DiffCategory + added_logic_lines: int + removed_logic_lines: int + total_changed_lines: int + is_comment_only: bool + is_whitespace_only: bool + + +# Patterns for comments across languages +_COMMENT_PATTERNS = [ + re.compile(r"^\s*#"), # Python + re.compile(r"^\s*//"), # JS/TS/Java + re.compile(r"^\s*\*"), # Multiline comment body + re.compile(r"^\s*/\*"), # Multiline comment start + re.compile(r"^\s*\*/"), # Multiline comment end + re.compile(r'^\s*"""'), # Python docstring + re.compile(r"^\s*'''"), # Python docstring +] + +# Patterns for import/require statements +_IMPORT_PATTERNS = [ + re.compile(r"^\s*(import |from \S+ import )"), # Python + re.compile(r"^\s*(const|let|var)\s+.*=\s*require\("), # JS require + re.compile(r"^\s*import\s+"), # JS/TS/Java import +] + + +def _is_comment_line(line: str) -> bool: + return any(p.match(line) for p in _COMMENT_PATTERNS) + + +def _is_import_line(line: str) -> bool: + return any(p.match(line) for p in _IMPORT_PATTERNS) + + +def _is_logic_line(line: str) -> bool: + stripped = line.strip() + if not stripped: + return False + if _is_comment_line(line): + return False + # String-only lines (just a string literal) + if stripped.startswith(('"""', "'''", '"', "'")) and stripped.endswith(('"""', "'''", '"', "'")): + return False + return True + + +def classify_function_diff(func: FunctionToOptimize, repo_directory: Path | None = None) -> FunctionDiffInfo: + """Classify the type of change made to a function based on git diff content.""" + if func.starting_line is None or func.ending_line is None: + return FunctionDiffInfo( + category=DiffCategory.MEANINGFUL, + added_logic_lines=0, + removed_logic_lines=0, + total_changed_lines=0, + is_comment_only=False, + is_whitespace_only=False, + ) + + diff_lines = _get_function_diff_lines(func, repo_directory) + if not diff_lines: + return FunctionDiffInfo( + category=DiffCategory.COSMETIC, + added_logic_lines=0, + removed_logic_lines=0, + total_changed_lines=0, + is_comment_only=False, + is_whitespace_only=True, + ) + + added_lines = [line for line in diff_lines if line.startswith("+")] + removed_lines = [line for line in diff_lines if line.startswith("-")] + total_changed = len(added_lines) + len(removed_lines) + + # Strip the +/- prefix for content analysis + added_content = [line[1:] for line in added_lines] + removed_content = [line[1:] for line in removed_lines] + + # Check if all changes are whitespace-only + added_stripped = [line.strip() for line in added_content] + removed_stripped = [line.strip() for line in removed_content] + if all(not s for s in added_stripped) and all(not s for s in removed_stripped): + return FunctionDiffInfo( + category=DiffCategory.COSMETIC, + added_logic_lines=0, + removed_logic_lines=0, + total_changed_lines=total_changed, + is_comment_only=False, + is_whitespace_only=True, + ) + + # Check if all changes are comment-only + added_logic = [line for line in added_content if _is_logic_line(line)] + removed_logic = [line for line in removed_content if _is_logic_line(line)] + + is_comment_only = len(added_logic) == 0 and len(removed_logic) == 0 + if is_comment_only: + return FunctionDiffInfo( + category=DiffCategory.COSMETIC, + added_logic_lines=0, + removed_logic_lines=0, + total_changed_lines=total_changed, + is_comment_only=True, + is_whitespace_only=False, + ) + + # Classify by logic change magnitude + logic_change_count = len(added_logic) + len(removed_logic) + + if logic_change_count <= 2: + category = DiffCategory.TRIVIAL + elif logic_change_count <= 10: + category = DiffCategory.MEANINGFUL + else: + category = DiffCategory.MAJOR + + return FunctionDiffInfo( + category=category, + added_logic_lines=len(added_logic), + removed_logic_lines=len(removed_logic), + total_changed_lines=total_changed, + is_comment_only=False, + is_whitespace_only=False, + ) + + +def _get_function_diff_lines(func: FunctionToOptimize, repo_directory: Path | None = None) -> list[str]: + """Extract diff lines that fall within a function's line range.""" + if repo_directory is None: + repo_directory = Path.cwd() + + try: + repository = git.Repo(repo_directory, search_parent_directories=True) + except git.InvalidGitRepositoryError: + return [] + + commit = repository.head.commit + try: + uni_diff_text = repository.git.diff( + commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True + ) + except git.GitCommandError: + return [] + + patch_set = PatchSet(StringIO(uni_diff_text)) + func_start = func.starting_line or 0 + func_end = func.ending_line or 0 + result: list[str] = [] + + for patched_file in patch_set: + file_path = Path(repository.working_dir) / patched_file.path + if file_path != func.file_path: + continue + + for hunk in patched_file: + for line in hunk: + if line.is_added and line.target_line_no and func_start <= line.target_line_no <= func_end: + result.append(f"+{line.value}") + elif line.is_removed and line.source_line_no and func_start <= line.source_line_no <= func_end: + result.append(f"-{line.value}") + + return result + + +def filter_cosmetic_diff_functions( + functions: dict[Path, list[FunctionToOptimize]], repo_directory: Path | None = None +) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + """Remove functions where the diff is purely cosmetic (comments/whitespace only).""" + filtered: dict[Path, list[FunctionToOptimize]] = {} + skipped_count = 0 + + for file_path, funcs in functions.items(): + kept: list[FunctionToOptimize] = [] + for func in funcs: + try: + diff_info = classify_function_diff(func, repo_directory) + except Exception: + kept.append(func) + continue + + if diff_info.category == DiffCategory.COSMETIC: + skipped_count += 1 + logger.debug( + f"Skipping {func.qualified_name} — diff is cosmetic " + f"({'comments only' if diff_info.is_comment_only else 'whitespace only'})" + ) + else: + kept.append(func) + + if kept: + filtered[file_path] = kept + + if skipped_count > 0: + logger.info(f"Diff analysis: skipped {skipped_count} function(s) with cosmetic-only changes") + + return filtered, skipped_count + + +def get_effort_for_diff(diff_info: FunctionDiffInfo) -> str | None: + """Suggest an effort level based on diff category. Returns None to use default.""" + if diff_info.category == DiffCategory.TRIVIAL: + return "low" + if diff_info.category == DiffCategory.MAJOR: + return "high" + return None diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 4d98fbde9..53db9376d 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -317,10 +317,22 @@ def get_functions_to_optimize( console.rule() ph("cli-optimizing-git-diff") functions = get_functions_within_git_diff(uncommitted_changes=False) + + # Skip functions with cosmetic-only diffs (comments/whitespace) + from codeflash.discovery.diff_classifier import filter_cosmetic_diff_functions + + functions, _ = filter_cosmetic_diff_functions(functions) filtered_modified_functions, functions_count = filter_functions( functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions ) + # Pre-screen functions by optimizability to skip trivial/unoptimizable code + from codeflash.discovery.optimizability_scorer import filter_by_optimizability + + filtered_modified_functions, prescreened_count = filter_by_optimizability(filtered_modified_functions) + if prescreened_count > 0: + functions_count -= prescreened_count + logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize") return filtered_modified_functions, functions_count, trace_file_path diff --git a/codeflash/discovery/optimizability_scorer.py b/codeflash/discovery/optimizability_scorer.py new file mode 100644 index 000000000..e121eee40 --- /dev/null +++ b/codeflash/discovery/optimizability_scorer.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +import ast +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from codeflash.cli_cmds.console import logger + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +MIN_LINES_FOR_OPTIMIZATION = 3 +DEFAULT_OPTIMIZABILITY_THRESHOLD = 0.15 + + +@dataclass(frozen=True) +class OptimizabilityScore: + function_name: str + score: float + reason: str + + @property + def is_optimizable(self) -> bool: + return self.score >= DEFAULT_OPTIMIZABILITY_THRESHOLD + + +def score_function_optimizability(func: FunctionToOptimize, source: str | None = None) -> OptimizabilityScore: + """Score a function's optimization potential using fast static analysis. + + Returns a score between 0.0 (not worth optimizing) and 1.0 (high potential). + """ + if func.starting_line is None or func.ending_line is None: + return OptimizabilityScore(func.qualified_name, 0.5, "unknown bounds") + + num_lines = func.ending_line - func.starting_line + 1 + if num_lines < MIN_LINES_FOR_OPTIMIZATION: + return OptimizabilityScore(func.qualified_name, 0.0, f"too small ({num_lines} lines)") + + if source is None: + try: + source = func.file_path.read_text(encoding="utf-8") + except OSError: + return OptimizabilityScore(func.qualified_name, 0.5, "could not read source") + + func_source = _extract_function_source(source, func.starting_line, func.ending_line) + if func_source is None: + return OptimizabilityScore(func.qualified_name, 0.5, "could not extract source") + + if func.language == "python": + return _score_python_function(func, func_source, num_lines) + if func.language in ("javascript", "typescript"): + return _score_by_heuristics(func, func_source, num_lines) + if func.language == "java": + return _score_by_heuristics(func, func_source, num_lines) + return OptimizabilityScore(func.qualified_name, 0.5, "unknown language") + + +def _extract_function_source(full_source: str, start_line: int, end_line: int) -> str | None: + lines = full_source.splitlines() + if start_line < 1 or end_line > len(lines): + return None + return "\n".join(lines[start_line - 1 : end_line]) + + +def _score_python_function(func: FunctionToOptimize, func_source: str, num_lines: int) -> OptimizabilityScore: + try: + tree = ast.parse(func_source) + except SyntaxError: + return _score_by_heuristics(func, func_source, num_lines) + + visitor = _PythonComplexityVisitor() + visitor.visit(tree) + + score = 0.0 + reasons: list[str] = [] + + # Size contribution (logarithmic, caps at ~0.3) + size_score = min(0.3, num_lines / 100) + score += size_score + + # Loop presence is a strong signal + if visitor.loop_count > 0: + loop_score = min(0.35, visitor.loop_count * 0.15) + score += loop_score + reasons.append(f"{visitor.loop_count} loop(s)") + + # Comprehension/generator expressions + if visitor.comprehension_count > 0: + score += min(0.15, visitor.comprehension_count * 0.08) + reasons.append(f"{visitor.comprehension_count} comprehension(s)") + + # Nested loops are a very strong signal + if visitor.max_loop_depth >= 2: + score += 0.2 + reasons.append(f"nested loops (depth {visitor.max_loop_depth})") + + # Recursion + if visitor.has_recursion: + score += 0.15 + reasons.append("recursive") + + # Mathematical operations (sorting, searching patterns) + if visitor.math_op_count > 0: + score += min(0.1, visitor.math_op_count * 0.03) + reasons.append("math ops") + + # Data structure operations + if visitor.collection_op_count > 0: + score += min(0.1, visitor.collection_op_count * 0.03) + reasons.append("collection ops") + + # Penalty: mostly I/O (file, network, DB) + if visitor.io_call_count > 0 and visitor.loop_count == 0: + io_ratio = visitor.io_call_count / max(1, visitor.total_call_count) + if io_ratio > 0.5: + score *= 0.3 + reasons.append(f"I/O dominated ({visitor.io_call_count} I/O calls)") + + # Penalty: simple delegation (just calls another function) + if num_lines <= 5 and visitor.total_call_count == 1 and visitor.loop_count == 0: + score *= 0.2 + reasons.append("simple delegation") + + # Penalty: only string formatting / logging + if visitor.is_mostly_string_ops: + score *= 0.3 + reasons.append("mostly string/logging ops") + + score = min(1.0, max(0.0, score)) + reason = ", ".join(reasons) if reasons else f"{num_lines} lines" + return OptimizabilityScore(func.qualified_name, score, reason) + + +def _score_by_heuristics(func: FunctionToOptimize, func_source: str, num_lines: int) -> OptimizabilityScore: + """Language-agnostic heuristic scoring using text patterns.""" + score = 0.0 + reasons: list[str] = [] + + # Size + score += min(0.3, num_lines / 100) + + # Loop keywords + loop_keywords = ("for ", "for(", "while ", "while(", ".forEach(", ".map(", ".filter(", ".reduce(") + loop_count = sum(1 for kw in loop_keywords if kw in func_source) + if loop_count > 0: + score += min(0.35, loop_count * 0.12) + reasons.append(f"loop patterns ({loop_count})") + + # Sorting/searching + sort_patterns = ("sort(", "sorted(", ".sort(", "binarySearch", "indexOf", "Collections.sort") + if any(p in func_source for p in sort_patterns): + score += 0.15 + reasons.append("sort/search ops") + + # Nested structure (indentation depth as proxy) + max_indent = max((len(line) - len(line.lstrip()) for line in func_source.splitlines() if line.strip()), default=0) + if max_indent > 16: # roughly 4+ levels of nesting + score += 0.15 + reasons.append("deep nesting") + + # Simple delegation penalty + if num_lines <= 5: + score *= 0.3 + reasons.append("very small") + + score = min(1.0, max(0.0, score)) + reason = ", ".join(reasons) if reasons else f"{num_lines} lines" + return OptimizabilityScore(func.qualified_name, score, reason) + + +class _PythonComplexityVisitor(ast.NodeVisitor): + def __init__(self) -> None: + self.loop_count = 0 + self.max_loop_depth = 0 + self.comprehension_count = 0 + self.has_recursion = False + self.math_op_count = 0 + self.collection_op_count = 0 + self.io_call_count = 0 + self.total_call_count = 0 + self.is_mostly_string_ops = False + self._current_loop_depth = 0 + self._current_func_name: str | None = None + self._string_op_count = 0 + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if self._current_func_name is None: + self._current_func_name = node.name + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + if self._current_func_name is None: + self._current_func_name = node.name + self.generic_visit(node) + + def visit_For(self, node: ast.For) -> None: + self.loop_count += 1 + self._current_loop_depth += 1 + self.max_loop_depth = max(self.max_loop_depth, self._current_loop_depth) + self.generic_visit(node) + self._current_loop_depth -= 1 + + def visit_While(self, node: ast.While) -> None: + self.loop_count += 1 + self._current_loop_depth += 1 + self.max_loop_depth = max(self.max_loop_depth, self._current_loop_depth) + self.generic_visit(node) + self._current_loop_depth -= 1 + + def visit_ListComp(self, node: ast.ListComp) -> None: + self.comprehension_count += 1 + self.generic_visit(node) + + def visit_SetComp(self, node: ast.SetComp) -> None: + self.comprehension_count += 1 + self.generic_visit(node) + + def visit_DictComp(self, node: ast.DictComp) -> None: + self.comprehension_count += 1 + self.generic_visit(node) + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None: + self.comprehension_count += 1 + self.generic_visit(node) + + def visit_Call(self, node: ast.Call) -> None: + self.total_call_count += 1 + func_name = _get_call_name(node) + if func_name: + # Recursion detection + if func_name == self._current_func_name: + self.has_recursion = True + # I/O patterns + io_names = {"open", "read", "write", "send", "recv", "connect", "execute", "fetch", "request", "urlopen"} + if func_name in io_names or func_name.startswith(("requests.", "urllib")): + self.io_call_count += 1 + # Math patterns + math_names = {"sum", "min", "max", "abs", "pow", "sqrt", "log", "exp", "ceil", "floor"} + if func_name in math_names or "numpy" in func_name or "np." in func_name or "math." in func_name: + self.math_op_count += 1 + # Collection operations + collection_names = {"sorted", "reversed", "enumerate", "zip", "filter", "map", "reduce"} + if func_name in collection_names: + self.collection_op_count += 1 + # String ops + string_names = {"format", "join", "split", "replace", "strip", "lower", "upper", "encode", "decode"} + if func_name in string_names or func_name.endswith((".format", ".join")): + self._string_op_count += 1 + + self.generic_visit(node) + + # After visiting all nodes, check if mostly string ops + if self.total_call_count > 0: + self.is_mostly_string_ops = (self._string_op_count / self.total_call_count) > 0.7 + + def visit_BinOp(self, node: ast.BinOp) -> None: + if isinstance(node.op, (ast.Mult, ast.Pow, ast.MatMult, ast.FloorDiv)): + self.math_op_count += 1 + self.generic_visit(node) + + +def _get_call_name(node: ast.Call) -> str | None: + if isinstance(node.func, ast.Name): + return node.func.id + if isinstance(node.func, ast.Attribute): + return node.func.attr + return None + + +def filter_by_optimizability( + functions: dict[Path, list[FunctionToOptimize]], threshold: float = DEFAULT_OPTIMIZABILITY_THRESHOLD +) -> tuple[dict[Path, list[FunctionToOptimize]], int]: + """Filter functions by optimizability score, returning only those above threshold.""" + filtered: dict[Path, list[FunctionToOptimize]] = {} + skipped_count = 0 + + for file_path, funcs in functions.items(): + try: + source = file_path.read_text(encoding="utf-8") + except OSError: + filtered[file_path] = funcs + continue + + kept: list[FunctionToOptimize] = [] + for func in funcs: + result = score_function_optimizability(func, source) + if result.is_optimizable: + kept.append(func) + else: + skipped_count += 1 + logger.debug(f"Skipping {func.qualified_name} (score={result.score:.2f}, reason: {result.reason})") + + if kept: + filtered[file_path] = kept + + if skipped_count > 0: + logger.info(f"Pre-screening: skipped {skipped_count} low-optimizability function(s)") + + return filtered, skipped_count diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index fef53a760..33a4b1a99 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -6,7 +6,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import send_completion_email @@ -79,6 +79,146 @@ def __init__(self, args: Namespace) -> None: self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None self.patch_files: list[Path] = [] self._cached_callee_counts: dict[tuple[Path, str], int] = {} + self._high_effort_top_n: int | None = None + + # Smart CI defaults (used when no explicit value is provided via CLI/env/config) + CI_DEFAULT_MAX_FUNCTIONS = 30 + CI_DEFAULT_MAX_TIME_MINUTES = 20 + CI_DEFAULT_HIGH_EFFORT_TOP_N = 5 + + @staticmethod + def _resolve_int_setting(args_value: int | None, env_var: str, ci_default: int | None = None) -> int | None: + """Resolve a setting with precedence: CLI arg > env var > CI default. + + Returns None if no value found at any level (meaning no limit). + """ + # CLI arg takes priority (already set from CLI or pyproject.toml) + if args_value is not None: + return args_value + # Environment variable next + env_val = os.environ.get(env_var) + if env_val is not None: + try: + return int(env_val) + except ValueError: + logger.warning(f"Invalid {env_var}={env_val!r} (expected integer), ignoring") + # CI auto-default last (only if running in CI) + if ci_default is not None and env_utils.is_ci(): + return ci_default + return None + + def _apply_ci_defaults(self) -> None: + """Resolve max_functions, max_time, and high_effort_top_n with full precedence chain. + + Precedence: CLI arg > env var > pyproject.toml (already merged into args) > CI auto-default. + Env vars: CODEFLASH_MAX_FUNCTIONS, CODEFLASH_MAX_TIME, CODEFLASH_HIGH_EFFORT_TOP_N. + """ + resolved_max_functions = self._resolve_int_setting( + getattr(self.args, "max_functions", None), + "CODEFLASH_MAX_FUNCTIONS", + ci_default=self.CI_DEFAULT_MAX_FUNCTIONS, + ) + resolved_max_time = self._resolve_int_setting( + getattr(self.args, "max_time", None), "CODEFLASH_MAX_TIME", ci_default=self.CI_DEFAULT_MAX_TIME_MINUTES + ) + # HIGH_EFFORT_TOP_N is not a CLI arg — only configurable via env var + self._high_effort_top_n = self._resolve_int_setting( + None, "CODEFLASH_HIGH_EFFORT_TOP_N", ci_default=self.CI_DEFAULT_HIGH_EFFORT_TOP_N + ) + + self.args.max_functions = resolved_max_functions + self.args.max_time = resolved_max_time + + applied: list[str] = [] + if resolved_max_functions is not None: + applied.append(f"max-functions={resolved_max_functions}") + if resolved_max_time is not None: + applied.append(f"max-time={resolved_max_time}m") + if self._high_effort_top_n is not None: + applied.append(f"high-effort-top-n={self._high_effort_top_n}") + + if applied and env_utils.is_ci(): + logger.info(f"CI mode: active limits ({', '.join(applied)})") + + # Maximum batch size for a single prescreening LLM call + PRESCREENING_BATCH_SIZE = 20 + + def _should_prescreen(self) -> bool: + if getattr(self.args, "no_prescreening", False): + return False + if getattr(self.args, "prescreening", False): + return True + # Auto-enable in CI when running --all (the most expensive mode) + if env_utils.is_ci() and getattr(self.args, "all", None) is not None: + return True + return False + + def _run_prescreening( + self, ranked_functions: list[tuple[Path, FunctionToOptimize]] + ) -> list[tuple[Path, FunctionToOptimize]]: + if not self._should_prescreen() or not ranked_functions: + return ranked_functions + + console.rule() + logger.info(f"loading|Pre-screening {len(ranked_functions)} function(s) with LLM assessment...") + + # Read source code for each function (batch by file for efficiency) + source_cache: dict[Path, str] = {} + function_inputs: list[dict[str, str]] = [] + for file_path, func in ranked_functions: + if file_path not in source_cache: + try: + source_cache[file_path] = file_path.read_text(encoding="utf-8") + except OSError: + continue + + source = source_cache[file_path] + if func.starting_line and func.ending_line: + lines = source.splitlines() + func_source = "\n".join(lines[func.starting_line - 1 : func.ending_line]) + else: + func_source = source # fallback: send whole file + + function_inputs.append( + {"qualified_name": func.qualified_name, "source_code": func_source, "language": func.language} + ) + + if not function_inputs: + return ranked_functions + + # Batch the prescreening calls + all_results: dict[str, dict[str, Any]] = {} + for batch_start in range(0, len(function_inputs), self.PRESCREENING_BATCH_SIZE): + batch = function_inputs[batch_start : batch_start + self.PRESCREENING_BATCH_SIZE] + result = self.aiservice_client.prescreen_functions(batch) + if result is not None: + all_results.update(result) + + if not all_results: + logger.debug("Prescreening returned no results, keeping all functions") + console.rule() + return ranked_functions + + # Filter based on prescreening results + kept: list[tuple[Path, FunctionToOptimize]] = [] + skipped = 0 + for file_path, func in ranked_functions: + func_result = all_results.get(func.qualified_name) + if func_result is None or func_result.get("optimizable", True): + kept.append((file_path, func)) + else: + skipped += 1 + reason = func_result.get("reason", "low optimization potential") + score = func_result.get("score", "?") + logger.debug(f"Prescreening filtered: {func.qualified_name} (score={score}, reason: {reason})") + + if skipped > 0: + logger.info( + f"LLM pre-screening: kept {len(kept)} of {len(ranked_functions)} functions " + f"(filtered {skipped} with low optimization potential)" + ) + console.rule() + return kept def run_benchmarks( self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int @@ -493,6 +633,9 @@ def run(self) -> None: logger.warning("PR is in draft mode, skipping optimization") return + # Smart CI defaults: apply conservative limits when running in CI + self._apply_ci_defaults() + if self.args.worktree: result = self.worktree_mode() if result.is_failure(): @@ -589,6 +732,24 @@ def run(self) -> None: globally_ranked_functions = self.rank_all_functions_globally( file_to_funcs_to_optimize, trace_file_path, call_graph=resolver, test_count_cache=test_count_cache ) + + # LLM-based pre-screening (if enabled) + globally_ranked_functions = self._run_prescreening(globally_ranked_functions) + + # Apply --max-functions limit after ranking and prescreening + max_functions = getattr(self.args, "max_functions", None) + if max_functions is not None and len(globally_ranked_functions) > max_functions: + skipped = len(globally_ranked_functions) - max_functions + logger.info( + f"--max-functions={max_functions}: optimizing top {max_functions} of " + f"{len(globally_ranked_functions)} ranked functions (skipping {skipped})" + ) + globally_ranked_functions = globally_ranked_functions[:max_functions] + + # Track start time for --max-time budget enforcement + optimization_start_time = time.monotonic() + max_time_minutes = getattr(self.args, "max_time", None) + # Cache for module preparation (avoid re-parsing same files) prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {} @@ -602,6 +763,17 @@ def run(self) -> None: # Optimize functions in globally ranked order for i, (original_module_path, function_to_optimize) in enumerate(globally_ranked_functions): + # Check --max-time budget before starting a new function + if max_time_minutes is not None: + elapsed_minutes = (time.monotonic() - optimization_start_time) / 60 + if elapsed_minutes >= max_time_minutes: + remaining = len(globally_ranked_functions) - i + logger.info( + f"--max-time={max_time_minutes}m budget exhausted after {elapsed_minutes:.1f}m. " + f"Optimized {i} functions, skipping remaining {remaining}." + ) + break + # Prepare module if not already cached if original_module_path not in prepared_modules: module_prep_result = self.prepare_module_for_optimization(original_module_path) @@ -622,11 +794,14 @@ def run(self) -> None: test_suffix = f", {test_count} tests" if test_count else "" effort_override: str | None = None - if i < HIGH_EFFORT_TOP_N and self.args.effort == EffortLevel.MEDIUM.value: + high_effort_limit = ( + self._high_effort_top_n if self._high_effort_top_n is not None else HIGH_EFFORT_TOP_N + ) + if i < high_effort_limit and self.args.effort == EffortLevel.MEDIUM.value: effort_override = EffortLevel.HIGH.value logger.debug( f"Escalating effort for {function_to_optimize.qualified_name} from medium to high" - f" (top {HIGH_EFFORT_TOP_N} ranked)" + f" (top {high_effort_limit} ranked)" ) logger.info( diff --git a/tests/test_diff_classifier.py b/tests/test_diff_classifier.py new file mode 100644 index 000000000..939a5a5bc --- /dev/null +++ b/tests/test_diff_classifier.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from codeflash.discovery.diff_classifier import ( + DiffCategory, + FunctionDiffInfo, + _is_comment_line, + _is_logic_line, + get_effort_for_diff, +) + + +class TestIsCommentLine: + def test_python_comment(self) -> None: + assert _is_comment_line(" # this is a comment") + assert _is_comment_line("# top-level comment") + + def test_js_comment(self) -> None: + assert _is_comment_line(" // this is a comment") + assert _is_comment_line("// top-level") + + def test_multiline_comment(self) -> None: + assert _is_comment_line(" * middle of block comment") + assert _is_comment_line(" /* start of block comment") + assert _is_comment_line(" */ end of block comment") + + def test_docstring(self) -> None: + assert _is_comment_line(' """docstring"""') + assert _is_comment_line(" '''docstring'''") + + def test_not_a_comment(self) -> None: + assert not _is_comment_line(" x = 1") + assert not _is_comment_line(" return x") + assert not _is_comment_line(" for i in range(10):") + + +class TestIsLogicLine: + def test_logic_lines(self) -> None: + assert _is_logic_line(" x = 1") + assert _is_logic_line(" return x + y") + assert _is_logic_line(" if condition:") + + def test_not_logic_lines(self) -> None: + assert not _is_logic_line("") + assert not _is_logic_line(" ") + assert not _is_logic_line(" # comment") + assert not _is_logic_line(" // comment") + + +class TestDiffCategory: + def test_cosmetic_whitespace(self) -> None: + info = FunctionDiffInfo( + category=DiffCategory.COSMETIC, + added_logic_lines=0, + removed_logic_lines=0, + total_changed_lines=3, + is_comment_only=False, + is_whitespace_only=True, + ) + assert info.category == DiffCategory.COSMETIC + assert info.is_whitespace_only + + def test_cosmetic_comments(self) -> None: + info = FunctionDiffInfo( + category=DiffCategory.COSMETIC, + added_logic_lines=0, + removed_logic_lines=0, + total_changed_lines=5, + is_comment_only=True, + is_whitespace_only=False, + ) + assert info.category == DiffCategory.COSMETIC + assert info.is_comment_only + + def test_trivial(self) -> None: + info = FunctionDiffInfo( + category=DiffCategory.TRIVIAL, + added_logic_lines=1, + removed_logic_lines=1, + total_changed_lines=2, + is_comment_only=False, + is_whitespace_only=False, + ) + assert info.category == DiffCategory.TRIVIAL + + def test_meaningful(self) -> None: + info = FunctionDiffInfo( + category=DiffCategory.MEANINGFUL, + added_logic_lines=5, + removed_logic_lines=3, + total_changed_lines=8, + is_comment_only=False, + is_whitespace_only=False, + ) + assert info.category == DiffCategory.MEANINGFUL + + def test_major(self) -> None: + info = FunctionDiffInfo( + category=DiffCategory.MAJOR, + added_logic_lines=15, + removed_logic_lines=10, + total_changed_lines=25, + is_comment_only=False, + is_whitespace_only=False, + ) + assert info.category == DiffCategory.MAJOR + + +class TestGetEffortForDiff: + def test_trivial_gets_low_effort(self) -> None: + info = FunctionDiffInfo(DiffCategory.TRIVIAL, 1, 0, 1, False, False) + assert get_effort_for_diff(info) == "low" + + def test_major_gets_high_effort(self) -> None: + info = FunctionDiffInfo(DiffCategory.MAJOR, 20, 10, 30, False, False) + assert get_effort_for_diff(info) == "high" + + def test_meaningful_gets_default(self) -> None: + info = FunctionDiffInfo(DiffCategory.MEANINGFUL, 5, 3, 8, False, False) + assert get_effort_for_diff(info) is None + + def test_cosmetic_gets_default(self) -> None: + info = FunctionDiffInfo(DiffCategory.COSMETIC, 0, 0, 2, True, False) + assert get_effort_for_diff(info) is None diff --git a/tests/test_optimizability_scorer.py b/tests/test_optimizability_scorer.py new file mode 100644 index 000000000..f5c1c983a --- /dev/null +++ b/tests/test_optimizability_scorer.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from codeflash.discovery.optimizability_scorer import ( + DEFAULT_OPTIMIZABILITY_THRESHOLD, + OptimizabilityScore, + filter_by_optimizability, + score_function_optimizability, +) +from codeflash.models.function_types import FunctionToOptimize + + +def _make_func( + name: str = "my_func", + start: int = 1, + end: int = 20, + language: str = "python", + file_path: Path | None = None, +) -> FunctionToOptimize: + return FunctionToOptimize( + function_name=name, + file_path=file_path or Path("/fake/module.py"), + starting_line=start, + ending_line=end, + language=language, + ) + + +class TestOptimizabilityScore: + def test_is_optimizable_above_threshold(self) -> None: + score = OptimizabilityScore("f", 0.5, "loops") + assert score.is_optimizable + + def test_is_not_optimizable_below_threshold(self) -> None: + score = OptimizabilityScore("f", 0.0, "too small") + assert not score.is_optimizable + + def test_threshold_boundary(self) -> None: + score = OptimizabilityScore("f", DEFAULT_OPTIMIZABILITY_THRESHOLD, "boundary") + assert score.is_optimizable + + +class TestScorePythonFunction: + def test_tiny_function_scores_zero(self) -> None: + source = "def f():\n return 1\n" + func = _make_func(start=1, end=2) + result = score_function_optimizability(func, source) + assert result.score == 0.0 + assert "too small" in result.reason + + def test_function_with_loops_scores_high(self) -> None: + source = "\n".join([ + "def process(data):", + " result = []", + " for item in data:", + " for sub in item:", + " result.append(sub * 2)", + " return result", + "", + "", + "", + "", + ]) + func = _make_func(start=1, end=6) + result = score_function_optimizability(func, source) + assert result.score >= 0.3 + assert "loop" in result.reason or "nested" in result.reason + + def test_simple_delegation_scores_low(self) -> None: + source = "\n".join([ + "def wrapper(x):", + " return other_func(x)", + ]) + func = _make_func(start=1, end=2) + result = score_function_optimizability(func, source) + assert result.score < DEFAULT_OPTIMIZABILITY_THRESHOLD + + def test_comprehension_contributes(self) -> None: + source = "\n".join([ + "def transform(data):", + " a = [x * 2 for x in data]", + " b = {k: v for k, v in pairs}", + " return a, b", + "", + "", + ]) + func = _make_func(start=1, end=4) + result = score_function_optimizability(func, source) + assert result.score > 0 + + def test_recursive_function_scores_well(self) -> None: + source = "\n".join([ + "def fibonacci(n):", + " if n <= 1:", + " return n", + " return fibonacci(n - 1) + fibonacci(n - 2)", + "", + "", + ]) + func = _make_func(name="fibonacci", start=1, end=4) + result = score_function_optimizability(func, source) + assert result.score >= DEFAULT_OPTIMIZABILITY_THRESHOLD + assert "recursive" in result.reason + + def test_large_function_gets_size_bonus(self) -> None: + lines = ["def big_func():"] + for i in range(50): + lines.append(f" x_{i} = {i}") + lines.append(" return x_0") + source = "\n".join(lines) + func = _make_func(start=1, end=52) + result = score_function_optimizability(func, source) + assert result.score > 0.1 + + def test_unknown_bounds_gets_neutral_score(self) -> None: + func = _make_func(start=None, end=None) + result = score_function_optimizability(func, "def f(): pass") + assert result.score == 0.5 + + +class TestScoreByHeuristics: + def test_js_function_with_loops(self) -> None: + source = "\n".join([ + "function processData(arr) {", + " const result = [];", + " for (let i = 0; i < arr.length; i++) {", + " result.push(arr[i] * 2);", + " }", + " return result;", + "", + "", + ]) + func = _make_func(start=1, end=6, language="javascript") + result = score_function_optimizability(func, source) + assert result.score >= DEFAULT_OPTIMIZABILITY_THRESHOLD + + def test_js_tiny_function_scores_low(self) -> None: + source = "function getId() { return this.id; }" + func = _make_func(start=1, end=1, language="javascript") + result = score_function_optimizability(func, source) + assert result.score == 0.0 + + +class TestFilterByOptimizability: + def test_filters_low_score_functions(self, tmp_path: Path) -> None: + # Create a file with a tiny function and a complex function + source = "\n".join([ + "def tiny():", + " return 1", + "", + "def complex_func(data):", + " result = []", + " for item in data:", + " for sub in item:", + " result.append(sub * 2)", + " return result", + ]) + file = tmp_path / "module.py" + file.write_text(source, encoding="utf-8") + + tiny = _make_func(name="tiny", start=1, end=2, file_path=file) + complex_fn = _make_func(name="complex_func", start=4, end=9, file_path=file) + + functions = {file: [tiny, complex_fn]} + filtered, skipped = filter_by_optimizability(functions) + assert skipped >= 1 + # complex_func should survive + assert any(f.function_name == "complex_func" for f in filtered.get(file, [])) + + def test_empty_input(self) -> None: + filtered, skipped = filter_by_optimizability({}) + assert filtered == {} + assert skipped == 0