diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index 3289a9568..34908d89d 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -213,6 +213,9 @@ def __init__( self._CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$") self._cast_re = re.compile(r"^\((\w+)\)") + # Cache for inferred types to avoid repeated expensive inference work + self._type_infer_cache: dict[tuple[str, str], str] = {} + def transform(self, source: str) -> str: """Remove assertions from source code, preserving target function calls. @@ -928,7 +931,13 @@ def _infer_return_type(self, assertion: AssertionMatch) -> str: # For assertEquals/assertNotEquals/assertSame, try to infer from the expected literal if method in JUNIT5_VALUE_ASSERTIONS: - return self._infer_type_from_assertion_args(assertion.original_text, method) + key = (assertion.original_text, method) + cached = self._type_infer_cache.get(key) + if cached is not None: + return cached + inferred = self._infer_type_from_assertion_args(assertion.original_text, method) + self._type_infer_cache[key] = inferred + return inferred # For fluent assertions (assertThat), type inference is harder — keep Object return "Object"