70 changes: 62 additions & 8 deletions datafast/llms.py
@@ -17,6 +17,7 @@

# LiteLLM
import litellm
from litellm.exceptions import RateLimitError
from litellm.utils import ModelResponse

# Internal imports
@@ -116,15 +117,27 @@ def _respect_rate_limit(self) -> None:
# Keep only timestamps within the last minute
self._request_timestamps = [
ts for ts in self._request_timestamps if current - ts < 60]
if len(self._request_timestamps) < self.rpm_limit:

# Be more conservative - wait if we're at 90% of the limit
conservative_limit = max(1, int(self.rpm_limit * 0.9))

if len(self._request_timestamps) < conservative_limit:
return

# Need to wait until the earliest request is outside the 60-second window
earliest = self._request_timestamps[0]
# Add a 1s margin to avoid accidental rate limit exceedance
sleep_time = 61 - (current - earliest)
# Add a 2s margin to avoid accidental rate limit exceedance
sleep_time = 62 - (current - earliest)
if sleep_time > 0:
logger.warning(f"Rate limit reached | Waiting {sleep_time:.1f}s")
logger.warning(
f"Rate limit approaching | Requests: {len(self._request_timestamps)}/{self.rpm_limit} | "
f"Waiting {sleep_time:.1f}s"
)
time.sleep(sleep_time)
# Clean up old timestamps after waiting
current = time.monotonic()
self._request_timestamps = [
ts for ts in self._request_timestamps if current - ts < 60]

@staticmethod
def _strip_code_fences(content: str) -> str:
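For reference, the sliding-window logic in the hunk above can be read as a standalone limiter. This is a minimal sketch, not the actual class in `datafast/llms.py`; the `SlidingWindowLimiter` name is illustrative, while the 90% early-trigger threshold and the 2s margin mirror the patched values.

```python
import time

class SlidingWindowLimiter:
    """Illustrative sliding-window requests-per-minute limiter."""

    def __init__(self, rpm_limit: int):
        self.rpm_limit = rpm_limit
        self._timestamps: list[float] = []

    def wait_if_needed(self) -> None:
        now = time.monotonic()
        # Keep only timestamps within the last 60-second window
        self._timestamps = [ts for ts in self._timestamps if now - ts < 60]
        # Trigger early at 90% of the limit, as in the patch
        conservative_limit = max(1, int(self.rpm_limit * 0.9))
        if len(self._timestamps) < conservative_limit:
            return
        earliest = self._timestamps[0]
        # 62 = 60s window + 2s safety margin, matching the patched sleep_time
        sleep_time = 62 - (now - earliest)
        if sleep_time > 0:
            time.sleep(sleep_time)
        # Re-prune after sleeping so stale entries don't re-trigger a wait
        now = time.monotonic()
        self._timestamps = [ts for ts in self._timestamps if now - ts < 60]

    def record(self) -> None:
        self._timestamps.append(time.monotonic())
```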
@@ -258,9 +271,33 @@ def generate(
if response_format is not None:
completion_params["response_format"] = response_format

# Call LiteLLM completion with batch messages
response: list[ModelResponse] = litellm.batch_completion(
**completion_params)
# Call LiteLLM completion with batch messages - retry on rate limit
max_retries = 3
retry_delay = 5 # Start with 5 seconds
response = None

for attempt in range(max_retries):
try:
response: list[ModelResponse] = litellm.batch_completion(
**completion_params)
break # Success, exit retry loop
except RateLimitError as e:
if attempt < max_retries - 1:
wait_time = retry_delay * (2 ** attempt) # Exponential backoff
logger.warning(
f"Rate limit hit | Provider: {self.provider_name} | Model: {self.model_id} | "
f"Attempt {attempt + 1}/{max_retries} | Waiting {wait_time}s before retry"
)
time.sleep(wait_time)
else:
logger.error(
f"Rate limit exceeded after {max_retries} attempts | "
f"Provider: {self.provider_name} | Model: {self.model_id}"
)
raise

if response is None:
raise RuntimeError("Failed to get response after retries")
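The retry loop above follows a standard exponential-backoff pattern: with `retry_delay = 5` and three attempts, the waits are 5s, then 10s, then a final re-raise. A hedged standalone sketch of the same pattern, with a generic `call` standing in for `litellm.batch_completion`:

```python
import time
from litellm.exceptions import RateLimitError

def call_with_backoff(call, max_retries: int = 3, retry_delay: int = 5):
    """Retry `call` on RateLimitError with exponential backoff."""
    for attempt in range(max_retries):
        try:
            return call()
        except RateLimitError:
            if attempt < max_retries - 1:
                # Doubles each attempt: 5s, 10s, 20s for attempts 0, 1, 2
                time.sleep(retry_delay * (2 ** attempt))
            else:
                # Last attempt: surface the rate limit error to the caller
                raise
    # Reached only if max_retries <= 0, mirroring the `response is None` guard
    raise RuntimeError("Failed to get response after retries")
```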

# Record timestamp for rate limiting (one timestamp per batch item)
if self.rpm_limit is not None:
@@ -270,7 +307,24 @@

# Extract content from each response
results = []
for one_response in response:
for idx, one_response in enumerate(response):
if isinstance(one_response, Exception):
if isinstance(one_response, RateLimitError):
logger.warning(
"Rate limit error in batch item | Provider: %s | Model: %s | Item: %d",
self.provider_name,
self.model_id,
idx,
)
raise RuntimeError(
f"Batch item {idx} failed during generation: {one_response}"
) from one_response

if not getattr(one_response, "choices", None):
raise RuntimeError(
f"Unexpected response type from LiteLLM batch completion at item {idx}: {type(one_response).__name__}"
)

content = one_response.choices[0].message.content

if response_format is not None:
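The per-item checks in the final hunk exist because `litellm.batch_completion` can return an `Exception` object in place of a `ModelResponse` for an individual failed item, which is the assumption this patch guards against. A usage sketch under that assumption; the model id and prompts are placeholders:

```python
import litellm

# Hypothetical model id and prompts, for illustration only
responses = litellm.batch_completion(
    model="gpt-4o-mini",
    messages=[
        [{"role": "user", "content": "Hello"}],
        [{"role": "user", "content": "World"}],
    ],
)

for idx, resp in enumerate(responses):
    if isinstance(resp, Exception):
        # A failed item arrives as the raised exception, not a response
        print(f"Item {idx} failed: {resp}")
    elif getattr(resp, "choices", None):
        print(resp.choices[0].message.content)
```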