diff --git a/api/main.py b/api/main.py
index d0b8c79..35fba4b 100644
--- a/api/main.py
+++ b/api/main.py
@@ -1,7 +1,31 @@
 from fastapi import FastAPI
+from fastapi.responses import HTMLResponse
 from api.routes import templates, forms
 
-app = FastAPI()
+app = FastAPI(title="FireForm API")
+
+@app.get("/", response_class=HTMLResponse)
+def root():
+    return """
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <title>FireForm API</title>
+    </head>
+    <body>
+        <h1>FireForm API</h1>
+        <p>Welcome to the FireForm API.</p>
+        <ul>
+            <li>
+                <a href="/docs">Interactive API docs (Swagger UI)</a>
+            </li>
+            <li>
+                <a href="/redoc">Alternative docs (ReDoc)</a>
+            </li>
+        </ul>
+    </body>
+    </html>
+    """
 
app.include_router(templates.router)
app.include_router(forms.router)
\ No newline at end of file
diff --git a/src/llm.py b/src/llm.py
index 70937f9..5a3dbde 100644
--- a/src/llm.py
+++ b/src/llm.py
@@ -23,58 +23,78 @@ def type_check_all(self):
Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
)
- def build_prompt(self, current_field):
+ def build_batch_prompt(self, field_list):
"""
- This method is in charge of the prompt engineering. It creates a specific prompt for each target field.
- @params: current_field -> represents the current element of the json that is being prompted.
+ Creates a prompt for batch extraction of all fields in JSON format.
"""
- prompt = f"""
+ prompt = f"""
SYSTEM PROMPT:
- You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
- You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
- only a single string containing the identified value for the JSON field.
- If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
- If you don't identify the value in the provided text, return "-1".
+ You are an AI assistant designed to extract information from transcribed voice recordings and format it as JSON.
+ You will receive the transcription and a list of JSON fields to identify.
+ Your output MUST be a valid JSON object where the keys are the field names and the values are the identified data.
+ If a value is not identified, use "-1".
+ If a field name is plural and you identify more than one value, use a ";" separated string.
+
+ Example format:
+ {{
+ "Field1": "value",
+ "Field2": "value1; value2",
+ "Field3": "-1"
+ }}
+
---
DATA:
- Target JSON field to find in text: {current_field}
-
+ Target JSON fields: {list(field_list)}
+
TEXT: {self._transcript_text}
"""
-
return prompt
def main_loop(self):
# self.type_check_all()
- for field in self._target_fields.keys():
- prompt = self.build_prompt(field)
- # print(prompt)
- # ollama_url = "http://localhost:11434/api/generate"
- ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
- ollama_url = f"{ollama_host}/api/generate"
-
- payload = {
- "model": "mistral",
- "prompt": prompt,
- "stream": False, # don't really know why --> look into this later.
- }
+ ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
+ ollama_url = f"{ollama_host}/api/generate"
+ model_name = os.getenv("OLLAMA_MODEL", "mistral")
+
+ prompt = self.build_batch_prompt(self._target_fields.keys())
+
+ payload = {
+ "model": model_name,
+ "prompt": prompt,
+ "stream": False,
+ "format": "json"
+ }
+
+ print(f"\t[LOG] Sending batch request to Ollama ({model_name})...")
+ try:
+ response = requests.post(ollama_url, json=payload, timeout=300)
+ response.raise_for_status()
+ json_data = response.json()
+ raw_response = json_data["response"]
+ # Parse the extracted JSON
try:
- response = requests.post(ollama_url, json=payload)
- response.raise_for_status()
- except requests.exceptions.ConnectionError:
- raise ConnectionError(
- f"Could not connect to Ollama at {ollama_url}. "
- "Please ensure Ollama is running and accessible."
- )
- except requests.exceptions.HTTPError as e:
- raise RuntimeError(f"Ollama returned an error: {e}")
-
- # parse response
- json_data = response.json()
- parsed_response = json_data["response"]
- # print(parsed_response)
- self.add_response_to_json(field, parsed_response)
+ extracted_data = json.loads(raw_response)
+ except json.JSONDecodeError:
+ # Fallback: find the first { and last }
+ start = raw_response.find('{')
+ end = raw_response.rfind('}')
+ if start != -1 and end != -1:
+ extracted_data = json.loads(raw_response[start:end+1])
+ else:
+ raise ValueError("Could not parse JSON from LLM response.")
+
+ # Process each field
+ for field, value in extracted_data.items():
+ self.add_response_to_json(field, str(value))
+
+ except requests.exceptions.ConnectionError:
+ raise ConnectionError(
+ f"Could not connect to Ollama at {ollama_url}. "
+ "Please ensure Ollama is running and accessible."
+ )
+ except Exception as e:
+ raise RuntimeError(f"Ollama/Extraction error: {e}")
print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
diff --git a/src/main.py b/src/main.py
index 5bb632b..1c12af7 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,8 +1,8 @@
import os
-# from backend import Fill
-from commonforms import prepare_form
+from typing import Union
+from commonforms import prepare_form
from pypdf import PdfReader
-from controller import Controller
+from src.controller import Controller
def input_fields(num_fields: int):
fields = []