diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py index 72a530179..684bb21e3 100644 --- a/codeflash/languages/java/parser.py +++ b/codeflash/languages/java/parser.py @@ -111,6 +111,11 @@ def __init__(self) -> None: """Initialize the Java analyzer.""" self._parser: Parser | None = None + # Small cache mapping source bytes -> parsed Tree to avoid repeated parsing + # for identical source content. This helps when the same source is queried + # multiple times by callers (common in the codebase). + self._tree_cache: dict[bytes, Tree] = {} + @property def parser(self) -> Parser: """Get the parser, creating it lazily.""" @@ -159,8 +164,8 @@ def find_methods( List of JavaMethodNode objects describing found methods. """ - source_bytes = source.encode("utf8") - tree = self.parse(source_bytes) + # Use cached parse tree when possible to avoid repeated expensive parse + tree, source_bytes = self._get_cached_tree(source) methods: list[JavaMethodNode] = [] self._walk_tree_for_methods( @@ -314,8 +319,7 @@ def find_classes(self, source: str) -> list[JavaClassNode]: List of JavaClassNode objects. """ - source_bytes = source.encode("utf8") - tree = self.parse(source_bytes) + tree, source_bytes = self._get_cached_tree(source) classes: list[JavaClassNode] = [] self._walk_tree_for_classes(tree.root_node, source_bytes, classes, is_inner=False) @@ -479,8 +483,7 @@ def find_fields(self, source: str, class_name: str | None = None) -> list[JavaFi List of JavaFieldInfo objects. """ - source_bytes = source.encode("utf8") - tree = self.parse(source_bytes) + tree, source_bytes = self._get_cached_tree(source) fields: list[JavaFieldInfo] = [] self._walk_tree_for_fields(tree.root_node, source_bytes, fields, current_class=None, target_class=class_name) @@ -678,6 +681,24 @@ def get_package_name(self, source: str) -> str | None: return None + def _get_cached_tree(self, source: str | bytes) -> tuple[Tree, bytes]: + """Return a cached parse tree for source or parse and cache it. + + This avoids reparsing identical source multiple times while the analyzer + instance lives, which is a major performance win when callers query + classes/fields/methods repeatedly for the same source. + """ + if isinstance(source, str): + source_bytes = source.encode("utf8") + else: + source_bytes = source + tree = self._tree_cache.get(source_bytes) + if tree is None: + tree = self.parse(source_bytes) + # Cache the tree for future calls + self._tree_cache[source_bytes] = tree + return tree, source_bytes + def get_java_analyzer() -> JavaAnalyzer: """Get a JavaAnalyzer instance. diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 92ddd44e2..d561b097f 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -144,61 +144,63 @@ def _insert_class_members( class_indent = _get_indentation(lines[class_line]) if class_line < len(lines) else "" member_indent = class_indent + " " - result = source + result_bytes = source_bytes + + # Compute original body insertion points once + body_start = body_node.start_byte + body_end = body_node.end_byte + + # We'll keep track of a byte-offset delta as we modify the bytes so subsequent + # insertions can use adjusted positions without reparsing. + delta = 0 + + # Insert fields at the beginning of the class body (after opening brace) # Insert fields at the beginning of the class body (after opening brace) if fields: - # Re-parse to get current positions - classes = analyzer.find_classes(result) - for cls in classes: - if cls.name == class_name: - body_node = cls.node.child_by_field_name("body") - break + insert_point = body_start + 1 + delta # After opening brace + + # Build field text using list to avoid repeated string concatenation + field_parts: list[str] = [] + field_parts.append("\n") + for field in fields: + field_lines = field.strip().splitlines(keepends=True) + indented_field = _apply_indentation(field_lines, member_indent) + field_parts.append(indented_field) + if not indented_field.endswith("\n"): + field_parts.append("\n") + field_text = "".join(field_parts) + field_bytes = field_text.encode("utf8") - if body_node: - result_bytes = result.encode("utf8") - insert_point = body_node.start_byte + 1 # After opening brace + before = result_bytes[:insert_point] + after = result_bytes[insert_point:] + result_bytes = before + field_bytes + after - # Format fields - field_text = "\n" - for field in fields: - field_lines = field.strip().splitlines(keepends=True) - indented_field = _apply_indentation(field_lines, member_indent) - field_text += indented_field - if not indented_field.endswith("\n"): - field_text += "\n" + delta += len(field_bytes) # Adjust for next insertion(s) - before = result_bytes[:insert_point] - after = result_bytes[insert_point:] - result = (before + field_text.encode("utf8") + after).decode("utf8") + # Insert methods at the end of the class body (before closing brace) # Insert methods at the end of the class body (before closing brace) if methods: - # Re-parse to get current positions - classes = analyzer.find_classes(result) - for cls in classes: - if cls.name == class_name: - body_node = cls.node.child_by_field_name("body") - break - - if body_node: - result_bytes = result.encode("utf8") - insert_point = body_node.end_byte - 1 # Before closing brace + insert_point = body_end - 1 + delta # Before closing brace, adjust by delta - # Format methods - method_text = "\n" - for method in methods: - method_lines = method.strip().splitlines(keepends=True) - indented_method = _apply_indentation(method_lines, member_indent) - method_text += indented_method - if not indented_method.endswith("\n"): - method_text += "\n" - - before = result_bytes[:insert_point] - after = result_bytes[insert_point:] - result = (before + method_text.encode("utf8") + after).decode("utf8") - - return result + # Build method text efficiently + method_parts: list[str] = [] + method_parts.append("\n") + for method in methods: + method_lines = method.strip().splitlines(keepends=True) + indented_method = _apply_indentation(method_lines, member_indent) + method_parts.append(indented_method) + if not indented_method.endswith("\n"): + method_parts.append("\n") + method_text = "".join(method_parts) + method_bytes = method_text.encode("utf8") + + before = result_bytes[:insert_point] + after = result_bytes[insert_point:] + result_bytes = before + method_bytes + after + + return result_bytes.decode("utf8") def replace_function(