diff --git a/HOWTO.md b/HOWTO.md
index 0e54c73..4c99aac 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -91,6 +91,18 @@ or
 
+## Access MinIO console (object storage)
+
+Run:
+```bash
+kubectl port-forward svc/minio 9001:9001 -n minio
+```
+
+Then open http://localhost:9001 and log in with:
+
+    user: minio_user
+    password: minio_password
+
 ## Troubleshooting
 
 Access the pgsql via local db client
diff --git a/backend/api/app.py b/backend/api/app.py
index 774b53e..99bc6c2 100644
--- a/backend/api/app.py
+++ b/backend/api/app.py
@@ -14,6 +14,7 @@
 from backend.api import (
     auth_routes,
+    batch_routes,
     compliance_report_routes,
     demo_routes,
     deployed_models_routes,
@@ -30,6 +31,8 @@
 from backend.domain.use_cases.demo_usecases import SimulationManager
 from backend.domain.use_cases.ds_simulation_usecases import DSSimulationManager
 from backend.infrastructure.grafana_dashboard_adapter import GrafanaDashboardAdapter
+from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter
+from backend.infrastructure.minio_storage_adapter import MinioStorageAdapter
 from backend.infrastructure.mlflow_handler_adapter import MLFlowHandlerAdapter
 from backend.infrastructure.model_info_pgsql_db_handler import ModelInfoPostgresDBHandler
 from backend.infrastructure.platform_config_pgsql_adapter import PlatformConfigPgsqlAdapter
@@ -61,7 +64,9 @@ async def lifespan(app: FastAPI):
     app.state.model_info_db_handler = ModelInfoPostgresDBHandler(db_config=config.pgsql_db_config)
     app.state.user_adapter = UserPgsqlDbAdapter(db_config=config.pgsql_db_config, admin_config=config.mp_admin_config)
     app.state.platform_config_handler = PlatformConfigPgsqlAdapter(db_config=config.pgsql_db_config)
+    app.state.object_storage_handler = MinioStorageAdapter()
     app.state.dashboard_handler = GrafanaDashboardAdapter()
+    app.state.batch_handler = K8sBatchPredictionAdapter()
     app.state.simulation_manager = SimulationManager()
     app.state.ds_simulation_manager = DSSimulationManager()
     app.state.task_status = {}
@@ -103,6 +108,7 @@ def create_app() -> FastAPI:
     app.include_router(model_infos_routes.router, prefix="/model_infos", tags=["Model Infos"])
     app.include_router(llm_routes.router, prefix="/ai", tags=["AI Assist"])
     app.include_router(compliance_report_routes.router, prefix="/compliance", tags=["Compliance Report"])
+    app.include_router(batch_routes.router, prefix="/{project_name}/batch", tags=["Batch Predictions"])
     app.include_router(demo_routes.router, prefix="/demo", tags=["Demo Simulation"])
 
     return app
diff --git a/backend/api/batch_routes.py b/backend/api/batch_routes.py
new file mode 100644
index 0000000..52c470b
--- /dev/null
+++ b/backend/api/batch_routes.py
@@ -0,0 +1,246 @@
+# Philippe Stepniewski
+import inspect
+import uuid
+
+from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Request, UploadFile
+from loguru import logger
+from starlette.responses import JSONResponse, Response
+
+from backend.domain.entities.batch_prediction import BatchPredictionStatus
+from backend.domain.ports.batch_prediction_handler import BatchPredictionHandler
+from backend.domain.ports.object_storage_handler import ObjectStorageHandler
+from backend.domain.ports.project_db_handler import ProjectDbHandler
+from backend.domain.ports.registry_handler import RegistryHandler
+from backend.domain.ports.user_handler import UserHandler
+from backend.domain.use_cases.auth_usecases import get_current_user, get_user_adapter
+from backend.domain.use_cases.batch_predict import (
+    cleanup_batch_predictions,
+    delete_batch_prediction,
+    download_batch_result,
+    get_batch_prediction_status,
+    list_batch_predictions,
+    submit_batch_prediction,
+)
+from backend.domain.use_cases.user_usecases import user_can_perform_action_for_project
+from backend.utils import sanitize_project_name
+
+router = APIRouter()
+
+
+def get_batch_handler(request: Request) -> BatchPredictionHandler:
+    return request.app.state.batch_handler
+
+
+def get_project_db_handler(request: Request) -> ProjectDbHandler:
+    return request.app.state.project_db_handler
+
+
+def get_object_storage_handler(request: Request) -> ObjectStorageHandler:
+    return request.app.state.object_storage_handler
+
+
+def get_registry_pool(request: Request) -> RegistryHandler:
+    return request.app.state.registry_pool
+
+
+def get_tasks_status(request: Request) -> dict:
+    return request.app.state.task_status
+
+
+def _get_project_registry_tracking_uri(project_name: str) -> str:
+    sanitized = sanitize_project_name(project_name)
+    return f"http://{sanitized}.{sanitized}.svc.cluster.local:5000"
+
+
+def _run_batch_submission(
+    tasks_status: dict,
+    job_id: str,
+    registry,
+    project_name: str,
+    model_name: str,
+    version: str,
+    file_content: bytes,
+    object_storage: ObjectStorageHandler,
+    batch_handler: BatchPredictionHandler,
+    project_db_handler: ProjectDbHandler,
+):
+    try:
+        tasks_status[job_id] = BatchPredictionStatus.BUILDING.value
+        submit_batch_prediction(
+            project_name=project_name,
+            model_name=model_name,
+            version=version,
+            file_content=file_content,
+            job_id=job_id,
+            object_storage=object_storage,
+            batch_handler=batch_handler,
+            project_db_handler=project_db_handler,
+            registry=registry,
+        )
+        del tasks_status[job_id]
+    except Exception as e:
+        logger.error(f"Batch submission failed for job {job_id}: {e}")
+        tasks_status[job_id] = BatchPredictionStatus.FAILED.value
+@router.post("/submit/{model_name}/{version}")
+async def route_submit_batch(
+    project_name: str,
+    model_name: str,
+    version: str,
+    request: Request,
+    background_tasks: BackgroundTasks,
+    file: UploadFile = File(...),
+    batch_handler: BatchPredictionHandler = Depends(get_batch_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
+    project_db_handler: ProjectDbHandler = Depends(get_project_db_handler),
+    registry_pool: RegistryHandler = Depends(get_registry_pool),
+    tasks_status: dict = Depends(get_tasks_status),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    file_content = await file.read()
+
+    registry = registry_pool.get_registry_adapter(project_name, _get_project_registry_tracking_uri(project_name))
+
+    job_id = str(uuid.uuid4())[:8]
+    tasks_status[job_id] = BatchPredictionStatus.BUILDING.value
+
+    background_tasks.add_task(
+        _run_batch_submission,
+        tasks_status,
+        job_id,
+        registry,
+        project_name,
+        model_name,
+        version,
+        file_content,
+        object_storage,
+        batch_handler,
+        project_db_handler,
+    )
+
+    return JSONResponse(
+        content={"job_id": job_id, "status": BatchPredictionStatus.BUILDING.value},
+        media_type="application/json",
+    )
+
+
+@router.get("/status/{job_id}")
+def route_batch_status(
+    project_name: str,
+    job_id: str,
+    batch_handler: BatchPredictionHandler = Depends(get_batch_handler),
+    tasks_status: dict = Depends(get_tasks_status),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    # Check if the job is still in the build phase (tracked in-memory)
+    if job_id in tasks_status:
+        status = tasks_status[job_id]
+        if status == BatchPredictionStatus.BUILDING.value:
+            return JSONResponse(
+                content={"job_id": job_id, "status": BatchPredictionStatus.BUILDING.value},
+                media_type="application/json",
+            )
+        if status == BatchPredictionStatus.FAILED.value:
+            return JSONResponse(
+                content={"job_id": job_id, "status": BatchPredictionStatus.FAILED.value},
+                media_type="application/json",
+            )
+
+    result = get_batch_prediction_status(project_name, job_id, batch_handler)
+    return JSONResponse(content=result, media_type="application/json")
+
+
+@router.get("/list")
+def route_list_batch(
+    project_name: str,
+    batch_handler: BatchPredictionHandler = Depends(get_batch_handler),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    result = list_batch_predictions(project_name, batch_handler)
+    return JSONResponse(content=result, media_type="application/json")
+
+
+@router.get("/download/{job_id}")
+def route_download_batch(
+    project_name: str,
+    job_id: str,
+    batch_handler: BatchPredictionHandler = Depends(get_batch_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    try:
+        content = download_batch_result(project_name, job_id, batch_handler, object_storage)
+        return Response(
+            content=content,
+            media_type="text/csv",
+            headers={"Content-Disposition": f"attachment; filename=predictions-{job_id}.csv"},
+        )
+    except Exception as e:
+        logger.error(f"Failed to download batch result: {e}")
+        raise HTTPException(status_code=404, detail="Batch result not found or not yet available")
+@router.delete("/{job_id}")
+def route_delete_batch(
+    project_name: str,
+    job_id: str,
+    batch_handler: BatchPredictionHandler = Depends(get_batch_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    result = delete_batch_prediction(project_name, job_id, batch_handler, object_storage)
+    return JSONResponse(content={"status": result}, media_type="application/json")
+
+
+@router.post("/cleanup")
+def route_cleanup_batch(
+    project_name: str,
+    batch_handler: BatchPredictionHandler = Depends(get_batch_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    deleted_count = cleanup_batch_predictions(project_name, batch_handler, object_storage)
+    return JSONResponse(content={"deleted": deleted_count}, media_type="application/json")
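For reviewers: a quick way to exercise the new routes once the API is up. A sketch only — the host, token, and project/model/version values are placeholders, and the bearer-token header assumes the platform's existing auth flow.

```bash
API=http://localhost:8000          # placeholder host
TOKEN="<token from /auth/login>"   # placeholder token

# Submit a CSV for batch prediction; returns a short job_id
curl -s -X POST "$API/my-project/batch/submit/credit_default/1" \
  -H "Authorization: Bearer $TOKEN" \
  -F "file=@demos/batch_prediction_test_credit_default.csv"

# Poll the job, then fetch the results once status is "completed"
curl -s "$API/my-project/batch/status/<job_id>" -H "Authorization: Bearer $TOKEN"
curl -s -o predictions.csv -H "Authorization: Bearer $TOKEN" \
  "$API/my-project/batch/download/<job_id>"
```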
diff --git a/backend/api/projects_routes.py b/backend/api/projects_routes.py
index 064b54c..d77507e 100644
--- a/backend/api/projects_routes.py
+++ b/backend/api/projects_routes.py
@@ -8,6 +8,7 @@
 from backend.domain.entities.project import Project
 from backend.domain.entities.role import Role
 from backend.domain.ports.model_registry import ModelRegistry
+from backend.domain.ports.object_storage_handler import ObjectStorageHandler
 from backend.domain.ports.project_db_handler import ProjectDbHandler
 from backend.domain.ports.registry_handler import RegistryHandler
 from backend.domain.ports.user_handler import UserHandler
@@ -24,6 +25,7 @@
     list_projects,
     list_projects_for_user,
     remove_project,
+    update_project_batch_enabled,
 )
 from backend.domain.use_cases.user_usecases import user_can_perform_action_for_project
 
@@ -34,6 +36,10 @@ def get_project_db_handler(request: Request) -> ProjectDbHandler:
     return request.app.state.project_db_handler
 
 
+def get_object_storage_handler(request: Request) -> ObjectStorageHandler:
+    return request.app.state.object_storage_handler
+
+
 @router.get("/list")
 def route_list_projects(
     project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler),
@@ -68,13 +74,14 @@ def route_project_info(
 def route_add_project(
     project: Project,
     project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
     user_adapter: UserHandler = Depends(get_user_adapter),
     current_user: dict = Depends(get_current_user),
 ) -> JSONResponse:
     user_can_perform_action_for_project(
         current_user, project_name="", action_name=inspect.currentframe().f_code.co_name, user_adapter=user_adapter
     )
-    status = add_project(project_db_handler=project_sqlite_db_handler, project=project)
+    status = add_project(project_db_handler=project_sqlite_db_handler, project=project, object_storage=object_storage)
     return JSONResponse(content={"status": status}, media_type="application/json")
 
 
@@ -82,13 +89,43 @@ def route_add_project(
 def route_remove_project(
     project_name: str,
     project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
     user_adapter: UserHandler = Depends(get_user_adapter),
     current_user: dict = Depends(get_current_user),
 ):
     user_can_perform_action_for_project(
         current_user, project_name="", action_name=inspect.currentframe().f_code.co_name, user_adapter=user_adapter
     )
-    return remove_project(project_sqlite_db_handler, project_name=project_name)
+    return remove_project(project_sqlite_db_handler, project_name=project_name, object_storage=object_storage)
+
+
+@router.patch("/{project_name}/batch_enabled")
+def route_update_batch_enabled(
+    project_name: str,
+    body: dict,
+    project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler),
+    object_storage: ObjectStorageHandler = Depends(get_object_storage_handler),
+    user_adapter: UserHandler = Depends(get_user_adapter),
+    current_user: dict = Depends(get_current_user),
+):
+    user_can_perform_action_for_project(
+        current_user,
+        project_name=project_name,
+        action_name=inspect.currentframe().f_code.co_name,
+        user_adapter=user_adapter,
+    )
+    batch_enabled = body.get("batch_enabled", False)
+    try:
+        status = update_project_batch_enabled(
+            project_db_handler=project_sqlite_db_handler,
+            project_name=project_name,
+            batch_enabled=batch_enabled,
+            object_storage=object_storage,
+        )
+    except Exception as e:
+        logger.error(f"Failed to update batch_enabled for project '{project_name}': {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+    return JSONResponse(content={"status": status}, media_type="application/json")
 
 
 @router.post("/{project_name}/add_user")
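The new PATCH route takes a plain JSON body rather than a Pydantic model, so toggling the flag from a shell looks like this (same placeholder host/token assumptions as above):

```bash
curl -s -X PATCH "$API/projects/my-project/batch_enabled" \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"batch_enabled": true}'
```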
diff --git a/backend/domain/entities/batch_prediction.py b/backend/domain/entities/batch_prediction.py
new file mode 100644
index 0000000..3bb29a5
--- /dev/null
+++ b/backend/domain/entities/batch_prediction.py
@@ -0,0 +1,45 @@
+# Philippe Stepniewski
+from datetime import datetime
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class BatchPredictionStatus(str, Enum):
+    BUILDING = "building"
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
+class BatchPrediction(BaseModel):
+    job_id: str
+    project_name: str
+    model_name: str
+    model_version: str
+    status: BatchPredictionStatus
+    input_path: str
+    output_path: str
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+    error_message: Optional[str] = None
+    row_count: Optional[int] = None
+
+    def to_json(self) -> dict:
+        return {
+            "job_id": self.job_id,
+            "project_name": self.project_name,
+            "model_name": self.model_name,
+            "model_version": self.model_version,
+            "status": self.status.value,
+            "input_path": self.input_path,
+            "output_path": self.output_path,
+            "created_at": self.created_at.isoformat(),
+            "started_at": self.started_at.isoformat() if self.started_at else None,
+            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
+            "error_message": self.error_message,
+            "row_count": self.row_count,
+        }
+ """ + src_path = os.path.join(PROJECT_DIR, "backend/domain/entities/docker/batch_predict_template.py") + logger.info(f"Copying batch predict template from {src_path} to {dest_path}") + shutil.copy(src_path, dest_path) + + def prepare_docker_context( registry: MLFlowModelRegistryAdapter, project_name: str, model_name: str, version: str ) -> str: @@ -132,6 +144,7 @@ def prepare_docker_context( """ path_dest = create_tmp_artefacts_folder(model_name, project_name, version, path=os.path.join(PROJECT_DIR, "tmp")) copy_fast_api_template_to_tmp_docker_folder(path_dest) + copy_batch_predict_template_to_tmp_docker_folder(path_dest) registry.download_model_artifacts(model_name, version, path_dest) return path_dest @@ -169,6 +182,30 @@ def clean_build_context(context_path: str) -> None: remove_directory(context_path) +def check_docker_image_exists(image_name: str) -> bool: + docker_host = os.environ.get("DOCKER_HOST") + logger.info(f"Checking if Docker image '{image_name}' exists with batch support (DOCKER_HOST={docker_host})") + try: + result = subprocess.run(["docker", "images", "-q", image_name], capture_output=True, text=True) + if not result.stdout.strip(): + logger.info(f"Docker image '{image_name}' does not exist") + return False + + # Verify the image contains the batch predict template (old images may not have it) + batch_template_path = "/opt/mlflow/batch_predict_template.py" + check_cmd = ["docker", "run", "--rm", "--entrypoint", "test", image_name, "-f", batch_template_path] + check = subprocess.run(check_cmd, capture_output=True) + has_batch = check.returncode == 0 + if not has_batch: + logger.info(f"Docker image '{image_name}' exists but lacks batch_predict_template.py, rebuild needed") + else: + logger.info(f"Docker image '{image_name}' exists with batch support") + return has_batch + except Exception as e: + logger.warning(f"Failed to check Docker image existence: {e}") + return False + + def sanitize_name(project_name: str) -> str: """Nettoie et format le nom pour être valid dans Kubernetes.""" sanitized_name = re.sub(r"[^a-z0-9-]", "-", project_name.lower()) diff --git a/backend/domain/entities/project.py b/backend/domain/entities/project.py index 6813ec2..ef60878 100644 --- a/backend/domain/entities/project.py +++ b/backend/domain/entities/project.py @@ -8,7 +8,14 @@ class Project(BaseModel): owner: str scope: str data_perimeter: str + batch_enabled: bool = False connection_parameters: Optional[str] = None def to_json(self) -> dict: - return {"data_perimeter": self.data_perimeter, "scope": self.scope, "owner": self.owner, "name": self.name} + return { + "data_perimeter": self.data_perimeter, + "scope": self.scope, + "owner": self.owner, + "name": self.name, + "batch_enabled": self.batch_enabled, + } diff --git a/backend/domain/ports/batch_prediction_handler.py b/backend/domain/ports/batch_prediction_handler.py new file mode 100644 index 0000000..1b646e6 --- /dev/null +++ b/backend/domain/ports/batch_prediction_handler.py @@ -0,0 +1,28 @@ +# Philippe Stepniewski +from abc import ABC, abstractmethod + +from backend.domain.entities.batch_prediction import BatchPrediction + + +class BatchPredictionHandler(ABC): + @abstractmethod + def create_batch_job( + self, project_name: str, model_name: str, model_version: str, input_path: str, output_path: str, job_id: str + ) -> BatchPrediction: + pass + + @abstractmethod + def get_job_status(self, project_name: str, job_id: str) -> BatchPrediction: + pass + + @abstractmethod + def list_batch_jobs(self, project_name: str) -> 
diff --git a/backend/domain/entities/docker/dockerfile_template.py b/backend/domain/entities/docker/dockerfile_template.py
index 949a5ab..c08e240 100644
--- a/backend/domain/entities/docker/dockerfile_template.py
+++ b/backend/domain/entities/docker/dockerfile_template.py
@@ -27,6 +27,7 @@ def __init__(self, python_version: str):
     #Copy artefacts and dependencies lists
     COPY custom_model /opt/mlflow
     COPY fast_api_template.py /opt/mlflow
+    COPY batch_predict_template.py /opt/mlflow
 
     # Install python model version
     RUN YAML_PYTHON_VERSION=$(grep -E "^ *- python=" /opt/mlflow/conda.yaml \
@@ -37,7 +38,7 @@
     # Install additional dependencies in the environment
     RUN uv venv
     RUN uv pip install -r /opt/mlflow/requirements.txt
-    RUN uv pip install uvicorn fastapi cloudpickle loguru mlflow python-multipart
+    RUN uv pip install uvicorn fastapi cloudpickle loguru mlflow python-multipart boto3
     RUN uv pip install opentelemetry-api opentelemetry-sdk opentelemetry-instrumentation-fastapi \
         opentelemetry-exporter-prometheus
diff --git a/backend/domain/entities/docker/utils.py b/backend/domain/entities/docker/utils.py
index c014d0b..400a0df 100644
--- a/backend/domain/entities/docker/utils.py
+++ b/backend/domain/entities/docker/utils.py
@@ -115,6 +115,18 @@ def copy_fast_api_template_to_tmp_docker_folder(dest_path: str) -> None:
     shutil.copy(src_path, dest_path)
 
 
+def copy_batch_predict_template_to_tmp_docker_folder(dest_path: str) -> None:
+    """
+    Copies the batch predict template to the specified destination path.
+
+    Args:
+        dest_path (str): The destination path where the batch predict template will be copied.
+    """
+    src_path = os.path.join(PROJECT_DIR, "backend/domain/entities/docker/batch_predict_template.py")
+    logger.info(f"Copying batch predict template from {src_path} to {dest_path}")
+    shutil.copy(src_path, dest_path)
+
+
 def prepare_docker_context(
     registry: MLFlowModelRegistryAdapter, project_name: str, model_name: str, version: str
 ) -> str:
@@ -132,6 +144,7 @@
     """
     path_dest = create_tmp_artefacts_folder(model_name, project_name, version, path=os.path.join(PROJECT_DIR, "tmp"))
     copy_fast_api_template_to_tmp_docker_folder(path_dest)
+    copy_batch_predict_template_to_tmp_docker_folder(path_dest)
     registry.download_model_artifacts(model_name, version, path_dest)
     return path_dest
 
@@ -169,6 +182,30 @@ def clean_build_context(context_path: str) -> None:
     remove_directory(context_path)
 
 
+def check_docker_image_exists(image_name: str) -> bool:
+    docker_host = os.environ.get("DOCKER_HOST")
+    logger.info(f"Checking if Docker image '{image_name}' exists with batch support (DOCKER_HOST={docker_host})")
+    try:
+        result = subprocess.run(["docker", "images", "-q", image_name], capture_output=True, text=True)
+        if not result.stdout.strip():
+            logger.info(f"Docker image '{image_name}' does not exist")
+            return False
+
+        # Verify the image contains the batch predict template (old images may not have it)
+        batch_template_path = "/opt/mlflow/batch_predict_template.py"
+        check_cmd = ["docker", "run", "--rm", "--entrypoint", "test", image_name, "-f", batch_template_path]
+        check = subprocess.run(check_cmd, capture_output=True)
+        has_batch = check.returncode == 0
+        if not has_batch:
+            logger.info(f"Docker image '{image_name}' exists but lacks batch_predict_template.py, rebuild needed")
+        else:
+            logger.info(f"Docker image '{image_name}' exists with batch support")
+        return has_batch
+    except Exception as e:
+        logger.warning(f"Failed to check Docker image existence: {e}")
+        return False
+
+
 def sanitize_name(project_name: str) -> str:
     """Cleans and formats the name so that it is valid in Kubernetes."""
     sanitized_name = re.sub(r"[^a-z0-9-]", "-", project_name.lower())
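The rebuild check boils down to two Docker calls; the shell equivalent below is handy when debugging why an image keeps being rebuilt (the image name is illustrative):

```bash
# 1. Does the image exist locally?
docker images -q my-project_credit_default_1_ctr

# 2. Does it already contain the batch entrypoint, or does it predate this change?
docker run --rm --entrypoint test my-project_credit_default_1_ctr \
  -f /opt/mlflow/batch_predict_template.py \
  && echo "batch-ready" || echo "rebuild needed"
```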
f"{project_name}/{model_name}/{version}/{job_id}/predictions-{job_id}.csv" + + logger.info(f"Uploading input file to {input_path}") + object_storage.upload_file(project_name, f"{model_name}/{version}/{job_id}/input.csv", file_content) + + if registry: + ensure_model_image_exists(registry, project_name, model_name, version) + + batch_prediction = batch_handler.create_batch_job( + project_name, model_name, version, input_path, output_path, job_id + ) + return batch_prediction.to_json() + + +def get_batch_prediction_status(project_name: str, job_id: str, batch_handler: BatchPredictionHandler): + batch_prediction = batch_handler.get_job_status(project_name, job_id) + return batch_prediction.to_json() + + +def list_batch_predictions(project_name: str, batch_handler: BatchPredictionHandler): + jobs = batch_handler.list_batch_jobs(project_name) + return [job.to_json() for job in jobs] + + +def download_batch_result( + project_name: str, + job_id: str, + batch_handler: BatchPredictionHandler, + object_storage: ObjectStorageHandler, +): + batch_prediction = batch_handler.get_job_status(project_name, job_id) + model = batch_prediction.model_name + version = batch_prediction.model_version + output_remote_path = f"{model}/{version}/{job_id}/predictions-{job_id}.csv" + content = object_storage.download_file(project_name, output_remote_path) + return content + + +def delete_batch_prediction( + project_name: str, + job_id: str, + batch_handler: BatchPredictionHandler, + object_storage: ObjectStorageHandler, +): + batch_prediction = batch_handler.get_job_status(project_name, job_id) + model_name = batch_prediction.model_name + model_version = batch_prediction.model_version + + batch_handler.delete_batch_job(project_name, job_id) + + prefix = f"{model_name}/{model_version}/{job_id}/" + files = object_storage.list_files(project_name, prefix) + for f in files: + object_storage.delete_file(project_name, f) + + logger.info(f"Deleted batch prediction {job_id} and associated files") + return True + + +def cleanup_batch_predictions( + project_name: str, + batch_handler: BatchPredictionHandler, + object_storage: ObjectStorageHandler, +): + finished_jobs = batch_handler.list_finished_jobs(project_name) + deleted_count = 0 + for job in finished_jobs: + job_id = job.job_id + batch_handler.delete_batch_job(project_name, job_id) + + prefix = f"{job.model_name}/{job.model_version}/{job_id}/" + files = object_storage.list_files(project_name, prefix) + for f in files: + object_storage.delete_file(project_name, f) + + deleted_count += 1 + logger.info(f"Cleaned up batch job {job_id} ({job.status.value})") + + logger.info(f"Cleaned up {deleted_count} finished batch jobs for project {project_name}") + return deleted_count diff --git a/backend/domain/use_cases/projects_usecases.py b/backend/domain/use_cases/projects_usecases.py index c7344a2..4b1fad5 100644 --- a/backend/domain/use_cases/projects_usecases.py +++ b/backend/domain/use_cases/projects_usecases.py @@ -1,6 +1,8 @@ +# Philippe Stepniewski from loguru import logger from backend.domain.entities.project import Project +from backend.domain.ports.object_storage_handler import ObjectStorageHandler from backend.domain.ports.project_db_handler import ProjectDbHandler from backend.domain.use_cases.deploy_registry import deploy_registry from backend.domain.use_cases.deployed_models import _remove_project_namespace @@ -22,8 +24,10 @@ def list_projects_for_user(user: str, project_db_handler: ProjectDbHandler) -> l return l_projects -def add_project(project_db_handler: 
diff --git a/backend/domain/use_cases/batch_predict.py b/backend/domain/use_cases/batch_predict.py
new file mode 100644
index 0000000..d216bd4
--- /dev/null
+++ b/backend/domain/use_cases/batch_predict.py
@@ -0,0 +1,118 @@
+# Philippe Stepniewski
+from fastapi import HTTPException
+from loguru import logger
+
+from backend.domain.entities.docker.utils import build_model_docker_image, check_docker_image_exists, sanitize_name
+from backend.domain.ports.batch_prediction_handler import BatchPredictionHandler
+from backend.domain.ports.object_storage_handler import ObjectStorageHandler
+from backend.domain.ports.project_db_handler import ProjectDbHandler
+
+
+def ensure_model_image_exists(registry, project_name: str, model_name: str, version: str):
+    image_name = sanitize_name(f"{project_name}_{model_name}_{version}_ctr")
+    if check_docker_image_exists(image_name):
+        logger.info(f"Docker image '{image_name}' already exists, skipping build")
+        return
+    logger.info(f"Docker image '{image_name}' not found, building...")
+    build_status = build_model_docker_image(registry, project_name, model_name, version)
+    if build_status == 0:
+        raise HTTPException(status_code=500, detail="Failed to build model image")
+    logger.info(f"Docker image '{image_name}' built successfully")
+
+
+def submit_batch_prediction(
+    project_name: str,
+    model_name: str,
+    version: str,
+    file_content: bytes,
+    job_id: str,
+    object_storage: ObjectStorageHandler,
+    batch_handler: BatchPredictionHandler,
+    project_db_handler: ProjectDbHandler,
+    registry=None,
+):
+    project = project_db_handler.get_project(project_name)
+    if not project.batch_enabled:
+        raise HTTPException(status_code=400, detail="Batch predictions are not enabled for this project")
+
+    input_path = f"{project_name}/{model_name}/{version}/{job_id}/input.csv"
+    output_path = f"{project_name}/{model_name}/{version}/{job_id}/predictions-{job_id}.csv"
+
+    logger.info(f"Uploading input file to {input_path}")
+    object_storage.upload_file(project_name, f"{model_name}/{version}/{job_id}/input.csv", file_content)
+
+    if registry:
+        ensure_model_image_exists(registry, project_name, model_name, version)
+
+    batch_prediction = batch_handler.create_batch_job(
+        project_name, model_name, version, input_path, output_path, job_id
+    )
+    return batch_prediction.to_json()
+
+
+def get_batch_prediction_status(project_name: str, job_id: str, batch_handler: BatchPredictionHandler):
+    batch_prediction = batch_handler.get_job_status(project_name, job_id)
+    return batch_prediction.to_json()
+
+
+def list_batch_predictions(project_name: str, batch_handler: BatchPredictionHandler):
+    jobs = batch_handler.list_batch_jobs(project_name)
+    return [job.to_json() for job in jobs]
+
+
+def download_batch_result(
+    project_name: str,
+    job_id: str,
+    batch_handler: BatchPredictionHandler,
+    object_storage: ObjectStorageHandler,
+):
+    batch_prediction = batch_handler.get_job_status(project_name, job_id)
+    model = batch_prediction.model_name
+    version = batch_prediction.model_version
+    output_remote_path = f"{model}/{version}/{job_id}/predictions-{job_id}.csv"
+    content = object_storage.download_file(project_name, output_remote_path)
+    return content
+
+
+def delete_batch_prediction(
+    project_name: str,
+    job_id: str,
+    batch_handler: BatchPredictionHandler,
+    object_storage: ObjectStorageHandler,
+):
+    batch_prediction = batch_handler.get_job_status(project_name, job_id)
+    model_name = batch_prediction.model_name
+    model_version = batch_prediction.model_version
+
+    batch_handler.delete_batch_job(project_name, job_id)
+
+    prefix = f"{model_name}/{model_version}/{job_id}/"
+    files = object_storage.list_files(project_name, prefix)
+    for f in files:
+        object_storage.delete_file(project_name, f)
+
+    logger.info(f"Deleted batch prediction {job_id} and associated files")
+    return True
+
+
+def cleanup_batch_predictions(
+    project_name: str,
+    batch_handler: BatchPredictionHandler,
+    object_storage: ObjectStorageHandler,
+):
+    finished_jobs = batch_handler.list_finished_jobs(project_name)
+    deleted_count = 0
+    for job in finished_jobs:
+        job_id = job.job_id
+        batch_handler.delete_batch_job(project_name, job_id)
+
+        prefix = f"{job.model_name}/{job.model_version}/{job_id}/"
+        files = object_storage.list_files(project_name, prefix)
+        for f in files:
+            object_storage.delete_file(project_name, f)
+
+        deleted_count += 1
+        logger.info(f"Cleaned up batch job {job_id} ({job.status.value})")
+
+    logger.info(f"Cleaned up {deleted_count} finished batch jobs for project {project_name}")
+    return deleted_count
diff --git a/backend/domain/use_cases/projects_usecases.py b/backend/domain/use_cases/projects_usecases.py
index c7344a2..4b1fad5 100644
--- a/backend/domain/use_cases/projects_usecases.py
+++ b/backend/domain/use_cases/projects_usecases.py
@@ -1,6 +1,8 @@
+# Philippe Stepniewski
 from loguru import logger
 
 from backend.domain.entities.project import Project
+from backend.domain.ports.object_storage_handler import ObjectStorageHandler
 from backend.domain.ports.project_db_handler import ProjectDbHandler
 from backend.domain.use_cases.deploy_registry import deploy_registry
 from backend.domain.use_cases.deployed_models import _remove_project_namespace
@@ -22,8 +24,10 @@ def list_projects_for_user(user: str, project_db_handler: ProjectDbHandler) -> l
     return l_projects
 
 
-def add_project(project_db_handler: ProjectDbHandler, project: Project) -> bool:
+def add_project(project_db_handler: ProjectDbHandler, project: Project, object_storage: ObjectStorageHandler) -> bool:
     deploy_registry(project.name)
+    if project.batch_enabled:
+        object_storage.ensure_project_space(project.name)
     status = project_db_handler.add_project(project)
     return status
 
@@ -33,10 +37,30 @@ def get_project_info(project_db_handler: ProjectDbHandler, project_name: str) ->
     return project
 
 
-def remove_project(project_db_handler: ProjectDbHandler, project_name: str) -> bool:
+def remove_project(
+    project_db_handler: ProjectDbHandler, project_name: str, object_storage: ObjectStorageHandler
+) -> bool:
     try:
         _remove_project_namespace(project_name)
     except Exception as e:
         logger.error(f"K8s cleanup failed for project '{project_name}', continuing with DB removal: {e}")
+    try:
+        object_storage.remove_project_space(project_name)
+    except Exception as e:
+        logger.error(f"Storage cleanup failed for project '{project_name}', continuing with DB removal: {e}")
     project_db_handler.remove_project(project_name)
     return True
+
+
+def update_project_batch_enabled(
+    project_db_handler: ProjectDbHandler,
+    project_name: str,
+    batch_enabled: bool,
+    object_storage: ObjectStorageHandler,
+) -> bool:
+    if batch_enabled:
+        object_storage.ensure_project_space(project_name)
+    else:
+        object_storage.remove_project_space(project_name)
+    project_db_handler.update_batch_enabled(project_name, batch_enabled)
+    return True
diff --git a/backend/infrastructure/k8s_batch_prediction_adapter.py b/backend/infrastructure/k8s_batch_prediction_adapter.py
new file mode 100644
index 0000000..ba5c2fe
--- /dev/null
+++ b/backend/infrastructure/k8s_batch_prediction_adapter.py
@@ -0,0 +1,194 @@
+# Philippe Stepniewski
+from datetime import datetime, timezone
+
+from kubernetes import client
+from kubernetes.client.rest import ApiException
+from loguru import logger
+
+from backend.domain.entities.batch_prediction import BatchPrediction, BatchPredictionStatus
+from backend.domain.ports.batch_prediction_handler import BatchPredictionHandler
+from backend.infrastructure.k8s_deployment import K8SDeployment
+from backend.utils import sanitize_project_name
+
+
+class K8sBatchPredictionAdapter(BatchPredictionHandler, K8SDeployment):
+    def __init__(self):
+        super().__init__()
+        self.batch_api = client.BatchV1Api()
+
+    def create_batch_job(
+        self, project_name: str, model_name: str, model_version: str, input_path: str, output_path: str, job_id: str
+    ) -> BatchPrediction:
+        namespace = sanitize_project_name(project_name)
+        docker_image_name = sanitize_project_name(f"{project_name}_{model_name}_{model_version}_ctr")
+
+        env_vars = [
+            client.V1EnvVar(name="INPUT_PATH", value=input_path),
+            client.V1EnvVar(name="OUTPUT_PATH", value=output_path),
+            client.V1EnvVar(name="BATCH_BUCKET", value="batch-predictions"),
+            client.V1EnvVar(name="MLFLOW_S3_ENDPOINT_URL", value=self._get_env("MLFLOW_S3_ENDPOINT_URL")),
+            client.V1EnvVar(name="AWS_ACCESS_KEY_ID", value=self._get_env("AWS_ACCESS_KEY_ID", "minio_user")),
+            client.V1EnvVar(
+                name="AWS_SECRET_ACCESS_KEY", value=self._get_env("AWS_SECRET_ACCESS_KEY", "minio_password")
+            ),
+        ]
+
+        job = client.V1Job(
+            metadata=client.V1ObjectMeta(
+                name=f"batch-{job_id}",
+                namespace=namespace,
+                labels={
+                    "app": "batch-prediction",
+                    "project": sanitize_project_name(project_name),
+                    "model": sanitize_project_name(model_name),
+                    "version": sanitize_project_name(model_version),
+                    "model-raw": model_name,
+                    "version-raw": model_version,
+                    "job_id": job_id,
+                },
+            ),
+            spec=client.V1JobSpec(
+                backoff_limit=1,
+                ttl_seconds_after_finished=3600,
+                template=client.V1PodTemplateSpec(
+                    metadata=client.V1ObjectMeta(
+                        labels={
+                            "app": "batch-prediction",
+                            "job_id": job_id,
+                        }
+                    ),
+                    spec=client.V1PodSpec(
+                        containers=[
+                            client.V1Container(
+                                name="batch-predict",
+                                image=f"{docker_image_name}:latest",
+                                image_pull_policy="IfNotPresent",
+                                command=["bash", "-c", "uv run python batch_predict_template.py"],
+                                env=env_vars,
+                            )
+                        ],
+                        restart_policy="Never",
+                    ),
+                ),
+            ),
+        )
+
+        self.batch_api.create_namespaced_job(namespace=namespace, body=job)
+        logger.info(f"Created batch job batch-{job_id} in namespace {namespace}")
+
+        return BatchPrediction(
+            job_id=job_id,
+            project_name=project_name,
+            model_name=model_name,
+            model_version=model_version,
+            status=BatchPredictionStatus.PENDING,
+            input_path=input_path,
+            output_path=output_path,
+            created_at=datetime.now(timezone.utc),
+        )
+
+    def get_job_status(self, project_name: str, job_id: str) -> BatchPrediction:
+        namespace = sanitize_project_name(project_name)
+        job = self.batch_api.read_namespaced_job(name=f"batch-{job_id}", namespace=namespace)
+        return self._job_to_batch_prediction(job)
+
+    def list_batch_jobs(self, project_name: str) -> list[BatchPrediction]:
+        namespace = sanitize_project_name(project_name)
+        jobs = self.batch_api.list_namespaced_job(namespace=namespace, label_selector="app=batch-prediction")
+        return [self._job_to_batch_prediction(job) for job in jobs.items]
+
+    def delete_batch_job(self, project_name: str, job_id: str) -> bool:
+        namespace = sanitize_project_name(project_name)
+        try:
+            self.batch_api.delete_namespaced_job(
+                name=f"batch-{job_id}",
+                namespace=namespace,
+                body=client.V1DeleteOptions(propagation_policy="Background"),
+            )
+            logger.info(f"Deleted batch job batch-{job_id} from namespace {namespace}")
+            return True
+        except ApiException as e:
+            if e.status == 404:
+                logger.warning(f"Batch job batch-{job_id} not found in namespace {namespace}")
+                return False
+            raise
+
+    def list_finished_jobs(self, project_name: str) -> list[BatchPrediction]:
+        namespace = sanitize_project_name(project_name)
+        jobs = self.batch_api.list_namespaced_job(namespace=namespace, label_selector="app=batch-prediction")
+        finished = []
+        for job in jobs.items:
+            bp = self._job_to_batch_prediction(job)
+            if bp.status in (BatchPredictionStatus.COMPLETED, BatchPredictionStatus.FAILED):
+                finished.append(bp)
+        return finished
+
+    def _job_to_batch_prediction(self, job: client.V1Job) -> BatchPrediction:
+        labels = job.metadata.labels or {}
+        status = self._map_job_status(job.status)
+
+        started_at = None
+        completed_at = None
+        error_message = None
+
+        if job.status.start_time:
+            started_at = job.status.start_time
+
+        if job.status.completion_time:
+            completed_at = job.status.completion_time
+
+        if status == BatchPredictionStatus.FAILED:
+            error_message = self._get_pod_error_logs(job.metadata.namespace, job.metadata.name)
+
+        env_vars = {}
+        containers = job.spec.template.spec.containers
+        if containers:
+            for env in containers[0].env or []:
+                env_vars[env.name] = env.value
+
+        return BatchPrediction(
+            job_id=labels.get("job_id", ""),
+            project_name=labels.get("project", ""),
+            model_name=labels.get("model-raw", labels.get("model", "")),
+            model_version=labels.get("version-raw", labels.get("version", "")),
+            status=status,
+            input_path=env_vars.get("INPUT_PATH", ""),
+            output_path=env_vars.get("OUTPUT_PATH", ""),
+            created_at=job.metadata.creation_timestamp or datetime.now(timezone.utc),
+            started_at=started_at,
+            completed_at=completed_at,
+            error_message=error_message,
+        )
+
+    def _map_job_status(self, status: client.V1JobStatus) -> BatchPredictionStatus:
+        if status.succeeded and status.succeeded > 0:
+            return BatchPredictionStatus.COMPLETED
+        if status.failed and status.failed > 0:
+            return BatchPredictionStatus.FAILED
+        if status.active and status.active > 0:
+            return BatchPredictionStatus.RUNNING
+        return BatchPredictionStatus.PENDING
+
+    def _get_pod_error_logs(self, namespace: str, job_name: str) -> str:
+        try:
+            pods = self.service_api_instance.list_namespaced_pod(
+                namespace=namespace, label_selector=f"job-name={job_name}"
+            )
+            if not pods.items:
+                return "No pod found for this job"
+            pod = pods.items[0]
+            logs = self.service_api_instance.read_namespaced_pod_log(
+                name=pod.metadata.name, namespace=namespace, tail_lines=30
+            )
+            lines = [line for line in logs.strip().splitlines() if line.strip()]
+            if not lines:
+                return "No logs available"
+            return "\n".join(lines[-5:])
+        except Exception as e:
+            logger.warning(f"Could not fetch pod logs for job {job_name}: {e}")
+            return "Could not retrieve error details"
+
+    def _get_env(self, key: str, default: str = "") -> str:
+        import os
+
+        return os.environ.get(key, default)
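Because jobs are named `batch-<job_id>` and labeled `app=batch-prediction`, they are easy to inspect with kubectl while reviewing. A sketch — the namespace is the sanitized project name and the job id is a placeholder (the `job-name` pod label is set automatically by the Job controller, which is also what `_get_pod_error_logs` relies on):

```bash
kubectl get jobs -n my-project -l app=batch-prediction
kubectl logs -n my-project -l job-name=batch-abc123 --tail=30
```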
response["Body"].read() + + def list_files(self, project_name: str, prefix: str = "") -> list[str]: + full_prefix = f"{project_name}/{prefix}" + paginator = self.s3.get_paginator("list_objects_v2") + files = [] + for page in paginator.paginate(Bucket=BATCH_BUCKET, Prefix=full_prefix): + for obj in page.get("Contents", []): + files.append(obj["Key"].removeprefix(f"{project_name}/")) + return files + + def delete_file(self, project_name: str, remote_path: str) -> None: + key = f"{project_name}/{remote_path}" + self.s3.delete_object(Bucket=BATCH_BUCKET, Key=key) + + def file_exists(self, project_name: str, remote_path: str) -> bool: + key = f"{project_name}/{remote_path}" + try: + self.s3.head_object(Bucket=BATCH_BUCKET, Key=key) + return True + except ClientError: + return False diff --git a/backend/infrastructure/project_pgsql_db_handler.py b/backend/infrastructure/project_pgsql_db_handler.py index 8f8b212..1a2fe45 100644 --- a/backend/infrastructure/project_pgsql_db_handler.py +++ b/backend/infrastructure/project_pgsql_db_handler.py @@ -68,10 +68,12 @@ def add_project(self, project: Project) -> bool: try: cursor = connection.cursor() query = """ - INSERT INTO projects (name, owner, scope, data_perimeter) - VALUES (%s, %s, %s, %s) \ + INSERT INTO projects (name, owner, scope, data_perimeter, batch_enabled) + VALUES (%s, %s, %s, %s, %s) \ """ - cursor.execute(query, (project.name, project.owner, project.scope, project.data_perimeter)) + cursor.execute( + query, (project.name, project.owner, project.scope, project.data_perimeter, project.batch_enabled) + ) connection.commit() finally: connection.close() @@ -87,6 +89,16 @@ def remove_project(self, name): connection.close() return True + def update_batch_enabled(self, name: str, batch_enabled: bool) -> bool: + connection = self._connect() + try: + cursor = connection.cursor() + cursor.execute("UPDATE projects SET batch_enabled = %s WHERE name = %s", (batch_enabled, name)) + connection.commit() + finally: + connection.close() + return True + def _init_table_project_if_not_exists(self): connection = self._connect() try: @@ -117,6 +129,7 @@ def _init_table_project_if_not_exists(self): ) \ """ cursor.execute(query) + cursor.execute("ALTER TABLE projects ADD COLUMN IF NOT EXISTS batch_enabled BOOLEAN DEFAULT FALSE") connection.commit() finally: connection.close() diff --git a/backend/infrastructure/project_sqlite_db_handler.py b/backend/infrastructure/project_sqlite_db_handler.py index 1b3390b..91557aa 100644 --- a/backend/infrastructure/project_sqlite_db_handler.py +++ b/backend/infrastructure/project_sqlite_db_handler.py @@ -18,7 +18,16 @@ def __init__(self, message, name=None): def map_rows_to_projects(rows: list) -> list[Project]: - return [Project(name=row[1], owner=row[2], scope=row[3], data_perimeter=row[4]) for row in rows] + return [ + Project( + name=row[1], + owner=row[2], + scope=row[3], + data_perimeter=row[4], + batch_enabled=bool(row[5]) if len(row) > 5 else False, + ) + for row in rows + ] class ProjectSQLiteDBHandler(ProjectDbHandler): @@ -73,10 +82,10 @@ def add_project(self, project: Project) -> bool: cursor = connection.cursor() cursor.execute( """ - INSERT INTO projects (name, owner, scope, data_perimeter) - VALUES (?, ?, ?, ?) + INSERT INTO projects (name, owner, scope, data_perimeter, batch_enabled) + VALUES (?, ?, ?, ?, ?) 
""", - (project.name, project.owner, project.scope, project.data_perimeter), + (project.name, project.owner, project.scope, project.data_perimeter, project.batch_enabled), ) connection.commit() @@ -94,6 +103,16 @@ def remove_project(self, name): connection.close() return True + def update_batch_enabled(self, name: str, batch_enabled: bool) -> bool: + connection = sqlite3.connect(self.db_path) + try: + cursor = connection.cursor() + cursor.execute("UPDATE projects SET batch_enabled = ? WHERE name = ?", (batch_enabled, name)) + connection.commit() + finally: + connection.close() + return True + def _init_table_project_if_not_exists(self): connection = sqlite3.connect(self.db_path) try: @@ -106,10 +125,16 @@ def _init_table_project_if_not_exists(self): name TEXT NOT NULL, owner TEXT NOT NULL, scope TEXT NOT NULL, - data_perimeter TEXT NOT NULL + data_perimeter TEXT NOT NULL, + batch_enabled BOOLEAN DEFAULT FALSE ) """ ) + # Migration: add batch_enabled column if missing + cursor.execute("PRAGMA table_info(projects)") + columns = [col[1] for col in cursor.fetchall()] + if "batch_enabled" not in columns: + cursor.execute("ALTER TABLE projects ADD COLUMN batch_enabled BOOLEAN DEFAULT FALSE") connection.commit() finally: connection.close() diff --git a/cli/commands/batch.py b/cli/commands/batch.py new file mode 100644 index 0000000..e7839dc --- /dev/null +++ b/cli/commands/batch.py @@ -0,0 +1,69 @@ +# Philippe Stepniewski +import typer + +from cli.utils.api_calls import get_and_print +from cli.utils.token import get_client + + +def submit_batch( + project_name: str, + model_name: str, + version: str, + file_path: str = typer.Option(..., help="Path to the CSV file to process"), +): + """Submit a batch prediction job""" + client = get_client() + with open(file_path, "rb") as f: + r = client.post( + f"/{project_name}/batch/submit/{model_name}/{version}", + files={"file": (file_path.split("/")[-1], f, "text/csv")}, + ) + if r.status_code == 200: + result = r.json() + print(f"Batch job submitted successfully. 
diff --git a/cli/commands/batch.py b/cli/commands/batch.py
new file mode 100644
index 0000000..e7839dc
--- /dev/null
+++ b/cli/commands/batch.py
@@ -0,0 +1,69 @@
+# Philippe Stepniewski
+import typer
+
+from cli.utils.api_calls import get_and_print
+from cli.utils.token import get_client
+
+
+def submit_batch(
+    project_name: str,
+    model_name: str,
+    version: str,
+    file_path: str = typer.Option(..., help="Path to the CSV file to process"),
+):
+    """Submit a batch prediction job"""
+    client = get_client()
+    with open(file_path, "rb") as f:
+        r = client.post(
+            f"/{project_name}/batch/submit/{model_name}/{version}",
+            files={"file": (file_path.split("/")[-1], f, "text/csv")},
+        )
+    if r.status_code == 200:
+        result = r.json()
+        print(f"Batch job submitted successfully. Job ID: {result.get('job_id', 'unknown')}")
+    else:
+        print(f"Error submitting batch job: {r.text}")
+
+
+def batch_status(project_name: str, job_id: str):
+    """Get the status of a batch prediction job"""
+    get_and_print(
+        f"/{project_name}/batch/status/{job_id}",
+        error_message="Error fetching batch job status",
+        success_message="Batch job status retrieved",
+    )
+
+
+def list_batch_jobs(project_name: str):
+    """List all batch prediction jobs for a project"""
+    get_and_print(
+        f"/{project_name}/batch/list",
+        error_message="Error listing batch jobs",
+        success_message="No batch jobs found",
+    )
+
+
+def download_batch_result(
+    project_name: str,
+    job_id: str,
+    output: str = typer.Option("predictions.csv", help="Output file path"),
+):
+    """Download the result of a batch prediction job"""
+    client = get_client()
+    r = client.get(f"/{project_name}/batch/download/{job_id}")
+    if r.status_code == 200:
+        with open(output, "wb") as f:
+            f.write(r.content)
+        print(f"Results downloaded to {output}")
+    else:
+        print(f"Error downloading batch result: {r.text}")
+
+
+def delete_batch_job(project_name: str, job_id: str):
+    """Delete a batch prediction job and its files"""
+    client = get_client()
+    r = client.delete(f"/{project_name}/batch/{job_id}")
+    if r.status_code == 200:
+        print("Batch job deleted successfully")
+    else:
+        print(f"Error deleting batch job: {r.text}")
project""" + patch_and_print( + f"/projects/{project_name}/batch_enabled", + {"batch_enabled": False}, + error_message="Error disabling batch predictions", + success_message="Batch predictions disabled successfully", ) @@ -47,7 +76,7 @@ def add_user_to_project(project_name: str, email: str = typer.Option(), role: st client = get_client() r = client.post(f"/projects/{project_name}/add_user?email={email}&role={role}") if r.status_code == 200: - print("✅ User added to project successfully") + print("User added to project successfully") else: print(r.content) - print("❌ Error adding user to project") + print("Error adding user to project") diff --git a/cli/main.py b/cli/main.py index f042c41..ba1706d 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,18 +1,29 @@ import typer from cli.commands.auth import login, me +from cli.commands.batch import batch_status, delete_batch_job, download_batch_result, list_batch_jobs, submit_batch from cli.commands.demo import get_status, list_simulations, start_simulation, stop_simulation from cli.commands.models import deploy_model, list_deployed_models, list_models, search_model_infos, undeploy_model -from cli.commands.projects import add_project, add_user_to_project, delete_project, list_projects, project_info +from cli.commands.projects import ( + add_project, + add_user_to_project, + delete_project, + disable_batch, + enable_batch, + list_projects, + project_info, +) from cli.commands.users import add_user, list_users app = typer.Typer() project_app = typer.Typer() user_app = typer.Typer() demo_app = typer.Typer() +batch_app = typer.Typer() app.add_typer(project_app, name="projects") app.add_typer(user_app, name="users") app.add_typer(demo_app, name="demo") +app.add_typer(batch_app, name="batch") app.command()(login) app.command()(me) @@ -27,11 +38,18 @@ project_app.command("undeploy")(undeploy_model) project_app.command("list-deployed-models")(list_deployed_models) project_app.command("delete")(delete_project) +project_app.command("enable-batch")(enable_batch) +project_app.command("disable-batch")(disable_batch) user_app.command("list")(list_users) user_app.command("add")(add_user) demo_app.command("list")(list_simulations) demo_app.command("start")(start_simulation) demo_app.command("stop")(stop_simulation) demo_app.command("status")(get_status) +batch_app.command("submit")(submit_batch) +batch_app.command("status")(batch_status) +batch_app.command("list")(list_batch_jobs) +batch_app.command("download")(download_batch_result) +batch_app.command("delete")(delete_batch_job) if __name__ == "__main__": app() diff --git a/cli/utils/api_calls.py b/cli/utils/api_calls.py index e235660..b6c1299 100644 --- a/cli/utils/api_calls.py +++ b/cli/utils/api_calls.py @@ -45,3 +45,15 @@ def post_and_print( else: print(r.content) print(error_message) + + +def patch_and_print( + endpoint: str, payload: dict, error_message: str = "❌ Error", success_message: str = "✅Success" +) -> None: + client = get_client() + r = client.patch(endpoint, json=payload) + if r.status_code == 200: + print(success_message) + else: + print(r.content) + print(error_message) diff --git a/demos/batch_prediction_test_credit_default.csv b/demos/batch_prediction_test_credit_default.csv new file mode 100644 index 0000000..ae67f3c --- /dev/null +++ b/demos/batch_prediction_test_credit_default.csv @@ -0,0 +1,21 @@ +age,income,loan_amount,loan_duration_months,credit_score,num_existing_loans,employment_years,missed_payments_12m,debt_to_income_ratio,loan_to_income_ratio 
+35,55000,15000,36,720,1,8,0,0.2727,0.2727 +28,32000,8000,24,650,2,3,1,0.25,0.25 +45,85000,40000,60,780,0,18,0,0.4706,0.4706 +52,120000,60000,84,810,1,25,0,0.5,0.5 +23,22000,5000,12,580,3,1,3,0.2273,0.2273 +61,95000,30000,48,750,0,30,0,0.3158,0.3158 +30,40000,20000,36,690,2,5,0,0.5,0.5 +42,70000,35000,60,710,1,15,1,0.5,0.5 +55,110000,25000,24,830,0,28,0,0.2273,0.2273 +26,28000,12000,36,620,4,2,2,0.4286,0.4286 +38,60000,18000,48,740,1,10,0,0.3,0.3 +48,90000,50000,84,700,2,20,1,0.5556,0.5556 +33,45000,10000,24,680,0,7,0,0.2222,0.2222 +29,35000,15000,36,640,3,4,2,0.4286,0.4286 +57,105000,20000,12,800,0,32,0,0.1905,0.1905 +40,65000,30000,60,730,1,12,0,0.4615,0.4615 +24,25000,7000,24,600,2,1,1,0.28,0.28 +50,100000,45000,84,770,0,22,0,0.45,0.45 +36,52000,22000,48,710,3,9,1,0.4231,0.4231 +44,78000,16000,36,760,0,16,0,0.2051,0.2051 diff --git a/frontend/js/api.js b/frontend/js/api.js index 23913bc..3d86fd5 100644 --- a/frontend/js/api.js +++ b/frontend/js/api.js @@ -82,6 +82,8 @@ const API = (() => { headers: { Authorization: `Bearer ${Auth.getToken()}` }, }).then(r => r.blob()), + updateBatchEnabled: (name, enabled) => + request('PATCH', `/projects/${name}/batch_enabled`, { body: { batch_enabled: enabled } }), registryStatus: (name) => get(`/projects/${name}/registry_status`).then(r => r.status).catch(() => 'error'), // User management @@ -223,6 +225,34 @@ const API = (() => { }), }, + // ── Batch Predictions ────────────────────────────────────── + batch: { + submit(proj, modelName, version, file) { + const fd = new FormData(); + fd.append('file', file); + return fetch(`${API_BASE}/${enc(proj)}/batch/submit/${enc(modelName)}/${enc(version)}`, { + method: 'POST', + headers: { Authorization: `Bearer ${Auth.getToken()}` }, + body: fd, + }).then(async r => { + if (!r.ok) throw new Error((await r.json()).detail || r.statusText); + return r.json(); + }); + }, + status: (proj, jobId) => get(`/${enc(proj)}/batch/status/${enc(jobId)}`), + list: (proj) => get(`/${enc(proj)}/batch/list`), + download(proj, jobId) { + return fetch(`${API_BASE}/${enc(proj)}/batch/download/${enc(jobId)}`, { + headers: { Authorization: `Bearer ${Auth.getToken()}` }, + }).then(async r => { + if (!r.ok) throw new Error((await r.json()).detail || r.statusText); + return r.blob(); + }); + }, + delete: (proj, jobId) => request('DELETE', `/${enc(proj)}/batch/${enc(jobId)}`), + cleanup: (proj) => request('POST', `/${enc(proj)}/batch/cleanup`), + }, + // ── Health ──────────────────────────────────────────────── health: { check: () => get('/health'), diff --git a/frontend/js/pages/project-detail.js b/frontend/js/pages/project-detail.js index 49d60c3..5a3ea19 100644 --- a/frontend/js/pages/project-detail.js +++ b/frontend/js/pages/project-detail.js @@ -39,12 +39,17 @@ const ProjectDetailPage = (() => { Deployed Models +
+
+      `;
 
@@ -84,6 +89,7 @@
       case 'models': loadModels(projectName, panel); break;
       case 'registry': loadRegistry(projectName, panel); break;
       case 'deployed': loadDeployed(projectName, panel); break;
+      case 'batch': loadBatch(projectName, panel); break;
     }
   }
 
@@ -107,6 +113,7 @@
     const owner = info.owner || info.Owner || '—';
     const scope = info.scope || info.Scope || '—';
     const perimeter = info.data_perimeter || info['Data Perimeter'] || '—';
+    const batchEnabled = info.batch_enabled || false;
     panel.innerHTML = `
+      No models available. Register models in MLflow to submit batch predictions.' : '';
+
+    const rows = (jobs || []).map(j => batchJobRow(j, projectName)).join('');
+
+    panel.innerHTML = `
+      | Job ID | Model | Version | Status | Created | Actions |
|---|
${escHtml(btn.dataset.error)}`,
+ confirmLabel: 'OK',
+ });
+ });
+ });
+
+ // Download buttons
+ panel.querySelectorAll('.batch-download-btn').forEach(btn => {
+ btn.addEventListener('click', async () => {
+ const jobId = btn.dataset.jobId;
+ try {
+ const blob = await API.batch.download(projectName, jobId);
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.href = url;
+ a.download = `predictions-${jobId}.csv`;
+ document.body.appendChild(a);
+ a.click();
+ document.body.removeChild(a);
+ URL.revokeObjectURL(url);
+ } catch (err) { Toast.error(err.message); }
+ });
+ });
+
+ // Poll for building/pending/running jobs
+ (jobs || []).forEach(j => {
+ const s = (j.status || '').toLowerCase();
+ if (s === 'building' || s === 'pending' || s === 'running') {
+ pollBatchStatus(projectName, j.job_id, panel);
+ }
+ });
+ }
+
+ function batchJobRow(j, projectName) {
+ const jobId = j.job_id || '—';
+ const model = j.model_name || '—';
+ const version = j.model_version || '—';
+ const status = j.status || 'unknown';
+ const created = j.created_at ? new Date(j.created_at).toLocaleString() : '—';
+ const isCompleted = status.toLowerCase() === 'completed';
+ const isFailed = status.toLowerCase() === 'failed';
+ const errorMsg = j.error_message || '';
+
+ return `
+