Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions classify-extract.py → classsify_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@

import argparse
import csv
import logging
import sys
from datetime import datetime
from pathlib import Path

from src.preprocessing.pdf_text_extraction import extract_text_from_pdf
from src.model.pdf_classifier import load_classifier, classify_text
from src.llm.llm_text import extract_key_sections
from src.llm.llm_client import extract_metrics_from_text, save_extraction_result
from src.utils.logger import setup_logging

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -73,12 +78,14 @@ def run_pipeline(
pdf_paths = sorted(input_path.glob("*.pdf"))
if not pdf_paths:
print(f"[ERROR] No PDF files found in directory: {input_path}", file=sys.stderr)
log.error("No PDF files found in directory: %s", input_path)
sys.exit(1)
print(f"[INFO] Found {len(pdf_paths)} PDF(s) in {input_path}", file=sys.stderr)
elif input_path.is_file() and input_path.suffix.lower() == ".pdf":
pdf_paths = [input_path]
else:
print(f"[ERROR] Input must be a .pdf file or a directory of PDFs: {input_path}", file=sys.stderr)
log.error("Input must be a .pdf file or a directory of PDFs: %s", input_path)
sys.exit(1)

# ── Load classifier once (avoid re-reading model artifacts per file) ──
Expand All @@ -87,6 +94,7 @@ def run_pipeline(
clf_model, vectorizer, encoder = load_classifier(model_dir)
except FileNotFoundError as e:
print(f"[ERROR] {e}", file=sys.stderr)
log.critical("Classifier artifacts not found: %s", e)
sys.exit(1)
print("[INFO] Classifier loaded.", file=sys.stderr)

Expand All @@ -111,24 +119,26 @@ def run_pipeline(
"fraction_feeding": "",
}

# ── Step 1: Extract text ─────────────────
# ── Step 1: Extract text ──────────────────────────────────────────
try:
original_text = extract_text_from_pdf(str(pdf_path))
except Exception as e:
print(f" [ERROR] Text extraction failed: {e}", file=sys.stderr)
log.error("Text extraction failed for %s: %s", pdf_path.name, e)
row["extraction_status"] = "text_extraction_failed"
summary_rows.append(row)
continue

if not original_text.strip():
print(f" [WARN] No text extracted from {pdf_path.name}. Skipping.", file=sys.stderr)
log.warning("No text extracted from %s — skipping.", pdf_path.name)
row["extraction_status"] = "empty_text"
summary_rows.append(row)
continue

print(f" [INFO] Text size: {len(original_text)} chars", file=sys.stderr)

# ── Step 2: Classify ──────────────────────────
# ── Step 2: Classify ──────────────────────────────────────────────
label, confidence, pred_prob = classify_text(
text=original_text,
model=clf_model,
Expand All @@ -142,25 +152,21 @@ def run_pipeline(
row["confidence"] = f"{confidence:.4f}"
row["pred_prob"] = f"{pred_prob:.4f}"

# ── Step 3: Extract ─────────────────
# ── Step 3: Extract ───────────────────────────────────────────────
if label == "useful":
print(f" [INFO] Running LLM extraction...", file=sys.stderr)

# Trim text to budget using section-priority logic (llm_text.py)
text_for_llm = original_text
if len(text_for_llm) > max_chars:
text_for_llm = extract_key_sections(text_for_llm, max_chars)
print(f" [INFO] Text trimmed to {len(text_for_llm)} chars (budget {max_chars})", file=sys.stderr)

try:
# LLM call (llm_client.py)
metrics = extract_metrics_from_text(
text=text_for_llm,
model=llm_model,
num_ctx=num_ctx,
)

# Resolve source pages + save JSON (llm_client.py)
result = save_extraction_result(
metrics=metrics,
source_file=pdf_path,
Expand All @@ -180,6 +186,7 @@ def run_pipeline(

except Exception as e:
print(f" [ERROR] LLM extraction failed: {e}", file=sys.stderr)
log.error("LLM extraction failed for %s: %s", pdf_path.name, e)
row["extraction_status"] = "extraction_failed"

else:
Expand All @@ -189,12 +196,11 @@ def run_pipeline(
summary_rows.append(row)

# ── Write summary CSV ─────────────────────────────────────────────────
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
summaries_dir = output_dir / "summaries"
summaries_dir.mkdir(parents=True, exist_ok=True)
summary_path = summaries_dir / f"pipeline_summary_{timestamp}.csv"

fieldnames = [
"filename", "classification", "confidence", "pred_prob",
"extraction_status", "species_name", "study_location", "study_date",
Expand Down Expand Up @@ -223,6 +229,9 @@ def run_pipeline(
print(f" Summary CSV : {summary_path}", file=sys.stderr)
print("=" * 50, file=sys.stderr)

if error_count > 0:
log.warning("Pipeline finished with %d error(s). See logs/fracfeed.log for details.", error_count)


# ---------------------------------------------------------------------------
# CLI entry point
Expand Down Expand Up @@ -295,9 +304,13 @@ def main():

args = parser.parse_args()

# Configure persistent logging for this process — one call covers all modules
setup_logging()

input_path = Path(args.input)
if not input_path.exists():
print(f"[ERROR] Input path not found: {input_path}", file=sys.stderr)
log.error("Input path not found: %s", input_path)
sys.exit(1)

run_pipeline(
Expand Down
37 changes: 15 additions & 22 deletions src/llm/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,14 @@
python llm_client.py path/to/file.txt
python llm_client.py path/to/file.pdf --model llama3.1:8b
python llm_client.py path/to/file.txt --output-dir results/

This script uses Ollama to extract structured data from predator diet surveys.
It can read PDFs directly (with automatic OCR for scanned pages) or preprocessed
text files. Extracted data includes species name, study date, location, and
stomach content metrics.
"""

import argparse
import json
import sys
import logging
import re
import sys
from pathlib import Path
from typing import Optional

from ollama import chat

Expand All @@ -30,6 +25,9 @@

from src.llm.models import PredatorDietMetrics
from src.llm.llm_text import extract_key_sections, load_document
from src.utils.logger import setup_logging

log = logging.getLogger(__name__)


def extract_metrics_from_text(
Expand Down Expand Up @@ -117,26 +115,17 @@ def save_extraction_result(
) -> dict:
"""Resolve source page numbers and save extraction results to JSON.

Looks for each extracted field value in ``original_text`` and records which
PDF page(s) the values were found on (using ``[PAGE N]`` markers). The
result dict and the JSON file both include a ``source_pages`` list.

Args:
metrics: Populated PredatorDietMetrics object returned by
:func:`extract_metrics_from_text`.
source_file: Original PDF/text path — used to name the output file
and to populate the ``source_file`` field in the JSON.
original_text: Full, un-truncated extracted text (with ``[PAGE N]``
markers) used for page-number resolution.
metrics: Populated PredatorDietMetrics object.
source_file: Original PDF/text path.
original_text: Full un-truncated extracted text (with [PAGE N] markers).
output_dir: Directory where the JSON result file will be written.

Returns:
The complete result dict that was written to disk, including
``source_file``, ``file_type``, and ``metrics`` (with ``source_pages``).
The complete result dict written to disk.
"""
metrics_dict = metrics.model_dump()

# Resolve which page(s) each extracted value came from
_skip_fields = {"fraction_feeding", "source_pages"}
source_pages: set[int] = set()
for field_name, value in metrics_dict.items():
Expand Down Expand Up @@ -171,26 +160,29 @@ def main():
parser.add_argument("input_file", type=str, help="Path to the input file (.pdf or .txt)")
parser.add_argument("--model", type=str, default="llama3.1:8b", help="Ollama model to use (default: llama3.1:8b)")
parser.add_argument("--output-dir", type=str, default="data/results", help="Output directory for JSON results (default: data/results/metrics)")
parser.add_argument("--max-chars", type=int, default=12000, help="Maximum characters of text to send to the model (default: 12000). Reduce if you hit CUDA/OOM errors.")
parser.add_argument("--max-chars", type=int, default=48000, help="Maximum characters of text to send to the model (default: 48000). Reduce if you hit CUDA/OOM errors.")
parser.add_argument("--num-ctx", type=int, default=4096, help="Context window size for the model (default: 4096). Lower values use less memory.")

args = parser.parse_args()

setup_logging()

input_path = Path(args.input_file)
if not input_path.exists():
print(f"[ERROR] File not found: {input_path}", file=sys.stderr)
log.error("File not found: %s", input_path)
sys.exit(1)

print(f"Processing {input_path.name}...", file=sys.stderr)
try:
original_text = load_document(input_path)
except Exception as e:
print(f"[ERROR] Failed to load file: {e}", file=sys.stderr)
log.error("Failed to load file %s: %s", input_path, e)
sys.exit(1)

print(f"[INFO] Text size: {len(original_text)} chars", file=sys.stderr)

# Trim to budget if needed
text = original_text
if len(text) > args.max_chars:
text = extract_key_sections(text, args.max_chars)
Expand All @@ -201,6 +193,7 @@ def main():
metrics = extract_metrics_from_text(text, model=args.model, num_ctx=args.num_ctx)
except Exception as e:
print(f"[ERROR] Extraction failed: {e}", file=sys.stderr)
log.error("Metric extraction failed for %s: %s", input_path.name, e)
sys.exit(1)

result = save_extraction_result(
Expand Down
14 changes: 6 additions & 8 deletions src/llm/llm_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
within LLM context windows.
"""

import logging
import re
import sys
from pathlib import Path
from typing import List, Tuple

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from src.preprocessing.pdf_text_extraction import extract_text_from_pdf

log = logging.getLogger(__name__)

# Section headers commonly found in scientific diet / stomach-content papers.
# Order matters: earlier entries are higher priority when budget is tight.
SECTION_PATTERNS: List[re.Pattern[str]] = [
Expand Down Expand Up @@ -43,7 +45,6 @@ def split_into_pages(text: str) -> List[Tuple[int, str]]:
List of (page_number, page_text) tuples
"""
parts = re.split(r"\[PAGE\s+(\d+)\]", text)
# parts: [before_first_marker, page_num, page_text, page_num, page_text, ...]
pages: List[Tuple[int, str]] = []
if parts[0].strip():
pages.append((0, parts[0]))
Expand Down Expand Up @@ -71,8 +72,6 @@ def classify_page(page_text: str) -> Tuple[bool, int]:
for idx, pat in enumerate(SECTION_PATTERNS):
if pat.search(page_text):
return True, idx
# No recognised header — still potentially useful (e.g. tables without
# a "Table" header, continuation of Results, etc.)
return True, len(SECTION_PATTERNS)


Expand Down Expand Up @@ -100,13 +99,12 @@ def extract_key_sections(text: str, max_chars: int) -> str:
return text

pages = split_into_pages(text)
scored: List[Tuple[int, int, str]] = [] # (priority, page_num, page_text)
scored: List[Tuple[int, int, str]] = []
for page_num, page_text in pages:
useful, priority = classify_page(page_text)
if useful:
scored.append((priority, page_num, page_text))

# Sort by priority (ascending = most important first)
scored.sort(key=lambda t: t[0])

selected: List[Tuple[int, str]] = []
Expand All @@ -117,12 +115,10 @@ def extract_key_sections(text: str, max_chars: int) -> str:
selected.append((page_num, page_with_marker))
budget -= len(page_with_marker)
elif budget > 200:
# Partially include the page up to the remaining budget
selected.append((page_num, page_with_marker[:budget]))
budget = 0
break

# Re-sort by page number so the LLM sees content in reading order
selected.sort(key=lambda t: t[0])
return "\n".join(chunk for _, chunk in selected)

Expand Down Expand Up @@ -150,6 +146,8 @@ def load_document(file_path: Path) -> str:
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
except UnicodeDecodeError as e:
log.error("Text file encoding error for %s: %s", file_path, e)
raise RuntimeError(f"Text file encoding error: {e}")
else:
log.error("Unsupported file type attempted: %s", file_path.suffix)
raise RuntimeError(f"Unsupported file type: {suffix}. Use .pdf or .txt files.")
1 change: 1 addition & 0 deletions src/model/models/pdf_classifier.json

Large diffs are not rendered by default.

Loading
Loading