diff --git a/HOWTO.md b/HOWTO.md index 0e54c73..4c99aac 100644 --- a/HOWTO.md +++ b/HOWTO.md @@ -91,6 +91,18 @@ or +## Access MinIO console (object storage) + +Run +```bash +kubectl port-forward svc/minio 9001:9001 -n minio +``` + +Then open http://localhost:9001 + + user: minio_user + password: minio_password + ## Troubleshooting Access the pgsql via local db client diff --git a/backend/api/app.py b/backend/api/app.py index 774b53e..99bc6c2 100644 --- a/backend/api/app.py +++ b/backend/api/app.py @@ -14,6 +14,7 @@ from backend.api import ( auth_routes, + batch_routes, compliance_report_routes, demo_routes, deployed_models_routes, @@ -30,6 +31,8 @@ from backend.domain.use_cases.demo_usecases import SimulationManager from backend.domain.use_cases.ds_simulation_usecases import DSSimulationManager from backend.infrastructure.grafana_dashboard_adapter import GrafanaDashboardAdapter +from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter +from backend.infrastructure.minio_storage_adapter import MinioStorageAdapter from backend.infrastructure.mlflow_handler_adapter import MLFlowHandlerAdapter from backend.infrastructure.model_info_pgsql_db_handler import ModelInfoPostgresDBHandler from backend.infrastructure.platform_config_pgsql_adapter import PlatformConfigPgsqlAdapter @@ -61,7 +64,9 @@ async def lifespan(app: FastAPI): app.state.model_info_db_handler = ModelInfoPostgresDBHandler(db_config=config.pgsql_db_config) app.state.user_adapter = UserPgsqlDbAdapter(db_config=config.pgsql_db_config, admin_config=config.mp_admin_config) app.state.platform_config_handler = PlatformConfigPgsqlAdapter(db_config=config.pgsql_db_config) + app.state.object_storage_handler = MinioStorageAdapter() app.state.dashboard_handler = GrafanaDashboardAdapter() + app.state.batch_handler = K8sBatchPredictionAdapter() app.state.simulation_manager = SimulationManager() app.state.ds_simulation_manager = DSSimulationManager() app.state.task_status = {} @@ -103,6 
+108,7 @@ def create_app() -> FastAPI: app.include_router(model_infos_routes.router, prefix="/model_infos", tags=["Model Infos"]) app.include_router(llm_routes.router, prefix="/ai", tags=["AI Assist"]) app.include_router(compliance_report_routes.router, prefix="/compliance", tags=["Compliance Report"]) + app.include_router(batch_routes.router, prefix="/{project_name}/batch", tags=["Batch Predictions"]) app.include_router(demo_routes.router, prefix="/demo", tags=["Demo Simulation"]) return app diff --git a/backend/api/batch_routes.py b/backend/api/batch_routes.py new file mode 100644 index 0000000..52c470b --- /dev/null +++ b/backend/api/batch_routes.py @@ -0,0 +1,246 @@ +# Philippe Stepniewski +import inspect +import uuid + +from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Request, UploadFile +from loguru import logger +from starlette.responses import JSONResponse, Response + +from backend.domain.entities.batch_prediction import BatchPredictionStatus +from backend.domain.ports.batch_prediction_handler import BatchPredictionHandler +from backend.domain.ports.object_storage_handler import ObjectStorageHandler +from backend.domain.ports.project_db_handler import ProjectDbHandler +from backend.domain.ports.registry_handler import RegistryHandler +from backend.domain.ports.user_handler import UserHandler +from backend.domain.use_cases.auth_usecases import get_current_user, get_user_adapter +from backend.domain.use_cases.batch_predict import ( + cleanup_batch_predictions, + delete_batch_prediction, + download_batch_result, + get_batch_prediction_status, + list_batch_predictions, + submit_batch_prediction, +) +from backend.domain.use_cases.user_usecases import user_can_perform_action_for_project +from backend.utils import sanitize_project_name + +router = APIRouter() + + +def get_batch_handler(request: Request) -> BatchPredictionHandler: + return request.app.state.batch_handler + + +def get_project_db_handler(request: Request) -> 
ProjectDbHandler: + return request.app.state.project_db_handler + + +def get_object_storage_handler(request: Request) -> ObjectStorageHandler: + return request.app.state.object_storage_handler + + +def get_registry_pool(request: Request) -> RegistryHandler: + return request.app.state.registry_pool + + +def get_tasks_status(request: Request) -> dict: + return request.app.state.task_status + + +def _get_project_registry_tracking_uri(project_name: str) -> str: + sanitized = sanitize_project_name(project_name) + return f"http://{sanitized}.{sanitized}.svc.cluster.local:5000" + + +def _run_batch_submission( + tasks_status: dict, + job_id: str, + registry, + project_name: str, + model_name: str, + version: str, + file_content: bytes, + object_storage: ObjectStorageHandler, + batch_handler: BatchPredictionHandler, + project_db_handler: ProjectDbHandler, +): + try: + tasks_status[job_id] = BatchPredictionStatus.BUILDING.value + submit_batch_prediction( + project_name=project_name, + model_name=model_name, + version=version, + file_content=file_content, + job_id=job_id, + object_storage=object_storage, + batch_handler=batch_handler, + project_db_handler=project_db_handler, + registry=registry, + ) + del tasks_status[job_id] + except Exception as e: + logger.error(f"Batch submission failed for job {job_id}: {e}") + tasks_status[job_id] = BatchPredictionStatus.FAILED.value + + +@router.post("/submit/{model_name}/{version}") +async def route_submit_batch( + project_name: str, + model_name: str, + version: str, + request: Request, + background_tasks: BackgroundTasks, + file: UploadFile = File(...), + batch_handler: BatchPredictionHandler = Depends(get_batch_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), + project_db_handler: ProjectDbHandler = Depends(get_project_db_handler), + registry_pool: RegistryHandler = Depends(get_registry_pool), + tasks_status: dict = Depends(get_tasks_status), + user_adapter: UserHandler = 
Depends(get_user_adapter), + current_user: dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + file_content = await file.read() + + registry = registry_pool.get_registry_adapter(project_name, _get_project_registry_tracking_uri(project_name)) + + job_id = str(uuid.uuid4())[:8] + tasks_status[job_id] = BatchPredictionStatus.BUILDING.value + + background_tasks.add_task( + _run_batch_submission, + tasks_status, + job_id, + registry, + project_name, + model_name, + version, + file_content, + object_storage, + batch_handler, + project_db_handler, + ) + + return JSONResponse( + content={"job_id": job_id, "status": BatchPredictionStatus.BUILDING.value}, + media_type="application/json", + ) + + +@router.get("/status/{job_id}") +def route_batch_status( + project_name: str, + job_id: str, + batch_handler: BatchPredictionHandler = Depends(get_batch_handler), + tasks_status: dict = Depends(get_tasks_status), + user_adapter: UserHandler = Depends(get_user_adapter), + current_user: dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + # Check if the job is still in the build phase (tracked in-memory) + if job_id in tasks_status: + status = tasks_status[job_id] + if status == BatchPredictionStatus.BUILDING.value: + return JSONResponse( + content={"job_id": job_id, "status": BatchPredictionStatus.BUILDING.value}, + media_type="application/json", + ) + if status == BatchPredictionStatus.FAILED.value: + return JSONResponse( + content={"job_id": job_id, "status": BatchPredictionStatus.FAILED.value}, + media_type="application/json", + ) + + result = get_batch_prediction_status(project_name, job_id, batch_handler) + return JSONResponse(content=result, 
media_type="application/json") + + +@router.get("/list") +def route_list_batch( + project_name: str, + batch_handler: BatchPredictionHandler = Depends(get_batch_handler), + user_adapter: UserHandler = Depends(get_user_adapter), + current_user: dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + result = list_batch_predictions(project_name, batch_handler) + return JSONResponse(content=result, media_type="application/json") + + +@router.get("/download/{job_id}") +def route_download_batch( + project_name: str, + job_id: str, + batch_handler: BatchPredictionHandler = Depends(get_batch_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), + user_adapter: UserHandler = Depends(get_user_adapter), + current_user: dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + try: + content = download_batch_result(project_name, job_id, batch_handler, object_storage) + return Response( + content=content, + media_type="text/csv", + headers={"Content-Disposition": f"attachment; filename=predictions-{job_id}.csv"}, + ) + except Exception as e: + logger.error(f"Failed to download batch result: {e}") + raise HTTPException(status_code=404, detail="Batch result not found or not yet available") + + +@router.delete("/{job_id}") +def route_delete_batch( + project_name: str, + job_id: str, + batch_handler: BatchPredictionHandler = Depends(get_batch_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), + user_adapter: UserHandler = Depends(get_user_adapter), + current_user: dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + 
action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + result = delete_batch_prediction(project_name, job_id, batch_handler, object_storage) + return JSONResponse(content={"status": result}, media_type="application/json") + + +@router.post("/cleanup") +def route_cleanup_batch( + project_name: str, + batch_handler: BatchPredictionHandler = Depends(get_batch_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), + user_adapter: UserHandler = Depends(get_user_adapter), + current_user: dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + deleted_count = cleanup_batch_predictions(project_name, batch_handler, object_storage) + return JSONResponse(content={"deleted": deleted_count}, media_type="application/json") diff --git a/backend/api/projects_routes.py b/backend/api/projects_routes.py index 064b54c..d77507e 100644 --- a/backend/api/projects_routes.py +++ b/backend/api/projects_routes.py @@ -8,6 +8,7 @@ from backend.domain.entities.project import Project from backend.domain.entities.role import Role from backend.domain.ports.model_registry import ModelRegistry +from backend.domain.ports.object_storage_handler import ObjectStorageHandler from backend.domain.ports.project_db_handler import ProjectDbHandler from backend.domain.ports.registry_handler import RegistryHandler from backend.domain.ports.user_handler import UserHandler @@ -24,6 +25,7 @@ list_projects, list_projects_for_user, remove_project, + update_project_batch_enabled, ) from backend.domain.use_cases.user_usecases import user_can_perform_action_for_project @@ -34,6 +36,10 @@ def get_project_db_handler(request: Request) -> ProjectDbHandler: return request.app.state.project_db_handler +def get_object_storage_handler(request: Request) -> ObjectStorageHandler: + return 
request.app.state.object_storage_handler + + @router.get("/list") def route_list_projects( project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler), @@ -68,13 +74,14 @@ def route_project_info( def route_add_project( project: Project, project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), user_adapter: UserHandler = Depends(get_user_adapter), current_user: dict = Depends(get_current_user), ) -> JSONResponse: user_can_perform_action_for_project( current_user, project_name="", action_name=inspect.currentframe().f_code.co_name, user_adapter=user_adapter ) - status = add_project(project_db_handler=project_sqlite_db_handler, project=project) + status = add_project(project_db_handler=project_sqlite_db_handler, project=project, object_storage=object_storage) return JSONResponse(content={"status": status}, media_type="application/json") @@ -82,13 +89,43 @@ def route_add_project( def route_remove_project( project_name: str, project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), user_adapter: UserHandler = Depends(get_user_adapter), current_user: dict = Depends(get_current_user), ): user_can_perform_action_for_project( current_user, project_name="", action_name=inspect.currentframe().f_code.co_name, user_adapter=user_adapter ) - return remove_project(project_sqlite_db_handler, project_name=project_name) + return remove_project(project_sqlite_db_handler, project_name=project_name, object_storage=object_storage) + + +@router.patch("/{project_name}/batch_enabled") +def route_update_batch_enabled( + project_name: str, + body: dict, + project_sqlite_db_handler: ProjectDbHandler = Depends(get_project_db_handler), + object_storage: ObjectStorageHandler = Depends(get_object_storage_handler), + user_adapter: UserHandler = Depends(get_user_adapter), + current_user: 
dict = Depends(get_current_user), +): + user_can_perform_action_for_project( + current_user, + project_name=project_name, + action_name=inspect.currentframe().f_code.co_name, + user_adapter=user_adapter, + ) + batch_enabled = body.get("batch_enabled", False) + try: + status = update_project_batch_enabled( + project_db_handler=project_sqlite_db_handler, + project_name=project_name, + batch_enabled=batch_enabled, + object_storage=object_storage, + ) + except Exception as e: + logger.error(f"Failed to update batch_enabled for project '{project_name}': {e}") + raise HTTPException(status_code=500, detail=str(e)) + return JSONResponse(content={"status": status}, media_type="application/json") @router.post("/{project_name}/add_user") diff --git a/backend/domain/entities/batch_prediction.py b/backend/domain/entities/batch_prediction.py new file mode 100644 index 0000000..3bb29a5 --- /dev/null +++ b/backend/domain/entities/batch_prediction.py @@ -0,0 +1,45 @@ +# Philippe Stepniewski +from datetime import datetime +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + + +class BatchPredictionStatus(str, Enum): + BUILDING = "building" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class BatchPrediction(BaseModel): + job_id: str + project_name: str + model_name: str + model_version: str + status: BatchPredictionStatus + input_path: str + output_path: str + created_at: datetime + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + error_message: Optional[str] = None + row_count: Optional[int] = None + + def to_json(self) -> dict: + return { + "job_id": self.job_id, + "project_name": self.project_name, + "model_name": self.model_name, + "model_version": self.model_version, + "status": self.status.value, + "input_path": self.input_path, + "output_path": self.output_path, + "created_at": self.created_at.isoformat(), + "started_at": self.started_at.isoformat() if 
self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "error_message": self.error_message, + "row_count": self.row_count, + } diff --git a/backend/domain/entities/docker/batch_predict_template.py b/backend/domain/entities/docker/batch_predict_template.py new file mode 100644 index 0000000..3a474f6 --- /dev/null +++ b/backend/domain/entities/docker/batch_predict_template.py @@ -0,0 +1,76 @@ +import os +import sys + +import boto3 +import mlflow +import pandas as pd +from loguru import logger + +logger.remove() +logger.add(sys.stderr, level="INFO") + + +def main(): + input_path = os.environ["INPUT_PATH"] + output_path = os.environ["OUTPUT_PATH"] + batch_bucket = os.environ.get("BATCH_BUCKET", "batch-predictions") + s3_endpoint = os.environ["MLFLOW_S3_ENDPOINT_URL"] + access_key = os.environ.get("AWS_ACCESS_KEY_ID", "minio_user") + secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "minio_password") + + s3 = boto3.client( + "s3", + endpoint_url=s3_endpoint, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + ) + + logger.info(f"Downloading input file from {batch_bucket}/{input_path}") + response = s3.get_object(Bucket=batch_bucket, Key=input_path) + input_data = response["Body"].read() + + local_input = "/tmp/input.csv" + with open(local_input, "wb") as f: + f.write(input_data) + + logger.info("Loading model from /opt/mlflow/") + model = mlflow.pyfunc.load_model("/opt/mlflow/") + + # Read expected column types from MLflow model signature and build dtype map for CSV parsing + csv_dtype = None + if model.metadata and model.metadata.signature: + schema = model.metadata.signature.inputs + type_mapping = { + "double": "float64", + "float": "float32", + "long": "int64", + "integer": "int32", + "string": "object", + } + csv_dtype = {} + for col in schema.inputs: + mlflow_type = str(col.type) + csv_dtype[col.name] = type_mapping.get(mlflow_type, "float64") + logger.info(f"Using model signature to cast 
CSV columns: {csv_dtype}") + + logger.info("Running predictions by chunks of 1000 rows") + all_predictions = [] + for chunk in pd.read_csv(local_input, chunksize=1000, dtype=csv_dtype): + predictions = model.predict(chunk) + if hasattr(predictions, "tolist"): + predictions = predictions.tolist() + all_predictions.extend(predictions) + + output_df = pd.DataFrame({"prediction": all_predictions}) + local_output = "/tmp/output.csv" + output_df.to_csv(local_output, index=False) + + logger.info(f"Uploading results to {batch_bucket}/{output_path}") + with open(local_output, "rb") as f: + s3.put_object(Bucket=batch_bucket, Key=output_path, Body=f.read()) + + logger.info(f"Batch prediction completed: {len(all_predictions)} predictions written") + + +if __name__ == "__main__": + main() diff --git a/backend/domain/entities/docker/dockerfile_template.py b/backend/domain/entities/docker/dockerfile_template.py index 949a5ab..c08e240 100644 --- a/backend/domain/entities/docker/dockerfile_template.py +++ b/backend/domain/entities/docker/dockerfile_template.py @@ -27,6 +27,7 @@ def __init__(self, python_version: str): #Copy artefacts and dependencies lists COPY custom_model /opt/mlflow COPY fast_api_template.py /opt/mlflow + COPY batch_predict_template.py /opt/mlflow # Install python model version RUN YAML_PYTHON_VERSION=$(grep -E "^ *- python=" /opt/mlflow/conda.yaml \ @@ -37,7 +38,7 @@ def __init__(self, python_version: str): # Install additional dependencies in the environment RUN uv venv RUN uv pip install -r /opt/mlflow/requirements.txt - RUN uv pip install uvicorn fastapi cloudpickle loguru mlflow python-multipart + RUN uv pip install uvicorn fastapi cloudpickle loguru mlflow python-multipart boto3 RUN uv pip install opentelemetry-api opentelemetry-sdk opentelemetry-instrumentation-fastapi \ opentelemetry-exporter-prometheus diff --git a/backend/domain/entities/docker/utils.py b/backend/domain/entities/docker/utils.py index c014d0b..400a0df 100644 --- 
a/backend/domain/entities/docker/utils.py +++ b/backend/domain/entities/docker/utils.py @@ -115,6 +115,18 @@ def copy_fast_api_template_to_tmp_docker_folder(dest_path: str) -> None: shutil.copy(src_path, dest_path) +def copy_batch_predict_template_to_tmp_docker_folder(dest_path: str) -> None: + """ + Copies the batch predict template to the specified destination path. + + Args: + dest_path (str): The destination path where the batch predict template will be copied. + """ + src_path = os.path.join(PROJECT_DIR, "backend/domain/entities/docker/batch_predict_template.py") + logger.info(f"Copying batch predict template from {src_path} to {dest_path}") + shutil.copy(src_path, dest_path) + + def prepare_docker_context( registry: MLFlowModelRegistryAdapter, project_name: str, model_name: str, version: str ) -> str: @@ -132,6 +144,7 @@ def prepare_docker_context( """ path_dest = create_tmp_artefacts_folder(model_name, project_name, version, path=os.path.join(PROJECT_DIR, "tmp")) copy_fast_api_template_to_tmp_docker_folder(path_dest) + copy_batch_predict_template_to_tmp_docker_folder(path_dest) registry.download_model_artifacts(model_name, version, path_dest) return path_dest @@ -169,6 +182,30 @@ def clean_build_context(context_path: str) -> None: remove_directory(context_path) +def check_docker_image_exists(image_name: str) -> bool: + docker_host = os.environ.get("DOCKER_HOST") + logger.info(f"Checking if Docker image '{image_name}' exists with batch support (DOCKER_HOST={docker_host})") + try: + result = subprocess.run(["docker", "images", "-q", image_name], capture_output=True, text=True) + if not result.stdout.strip(): + logger.info(f"Docker image '{image_name}' does not exist") + return False + + # Verify the image contains the batch predict template (old images may not have it) + batch_template_path = "/opt/mlflow/batch_predict_template.py" + check_cmd = ["docker", "run", "--rm", "--entrypoint", "test", image_name, "-f", batch_template_path] + check = 
subprocess.run(check_cmd, capture_output=True) + has_batch = check.returncode == 0 + if not has_batch: + logger.info(f"Docker image '{image_name}' exists but lacks batch_predict_template.py, rebuild needed") + else: + logger.info(f"Docker image '{image_name}' exists with batch support") + return has_batch + except Exception as e: + logger.warning(f"Failed to check Docker image existence: {e}") + return False + + def sanitize_name(project_name: str) -> str: """Nettoie et format le nom pour être valid dans Kubernetes.""" sanitized_name = re.sub(r"[^a-z0-9-]", "-", project_name.lower()) diff --git a/backend/domain/entities/project.py b/backend/domain/entities/project.py index 6813ec2..ef60878 100644 --- a/backend/domain/entities/project.py +++ b/backend/domain/entities/project.py @@ -8,7 +8,14 @@ class Project(BaseModel): owner: str scope: str data_perimeter: str + batch_enabled: bool = False connection_parameters: Optional[str] = None def to_json(self) -> dict: - return {"data_perimeter": self.data_perimeter, "scope": self.scope, "owner": self.owner, "name": self.name} + return { + "data_perimeter": self.data_perimeter, + "scope": self.scope, + "owner": self.owner, + "name": self.name, + "batch_enabled": self.batch_enabled, + } diff --git a/backend/domain/ports/batch_prediction_handler.py b/backend/domain/ports/batch_prediction_handler.py new file mode 100644 index 0000000..1b646e6 --- /dev/null +++ b/backend/domain/ports/batch_prediction_handler.py @@ -0,0 +1,28 @@ +# Philippe Stepniewski +from abc import ABC, abstractmethod + +from backend.domain.entities.batch_prediction import BatchPrediction + + +class BatchPredictionHandler(ABC): + @abstractmethod + def create_batch_job( + self, project_name: str, model_name: str, model_version: str, input_path: str, output_path: str, job_id: str + ) -> BatchPrediction: + pass + + @abstractmethod + def get_job_status(self, project_name: str, job_id: str) -> BatchPrediction: + pass + + @abstractmethod + def list_batch_jobs(self, 
project_name: str) -> list[BatchPrediction]: + pass + + @abstractmethod + def delete_batch_job(self, project_name: str, job_id: str) -> bool: + pass + + @abstractmethod + def list_finished_jobs(self, project_name: str) -> list[BatchPrediction]: + pass diff --git a/backend/domain/ports/object_storage_handler.py b/backend/domain/ports/object_storage_handler.py new file mode 100644 index 0000000..6ce206f --- /dev/null +++ b/backend/domain/ports/object_storage_handler.py @@ -0,0 +1,32 @@ +# Philippe Stepniewski +from abc import ABC, abstractmethod + + +class ObjectStorageHandler(ABC): + @abstractmethod + def ensure_project_space(self, project_name: str) -> None: + pass + + @abstractmethod + def remove_project_space(self, project_name: str) -> None: + pass + + @abstractmethod + def upload_file(self, project_name: str, remote_path: str, file_content: bytes) -> None: + pass + + @abstractmethod + def download_file(self, project_name: str, remote_path: str) -> bytes: + pass + + @abstractmethod + def list_files(self, project_name: str, prefix: str = "") -> list[str]: + pass + + @abstractmethod + def delete_file(self, project_name: str, remote_path: str) -> None: + pass + + @abstractmethod + def file_exists(self, project_name: str, remote_path: str) -> bool: + pass diff --git a/backend/domain/ports/project_db_handler.py b/backend/domain/ports/project_db_handler.py index 2d6bb04..85f4aae 100644 --- a/backend/domain/ports/project_db_handler.py +++ b/backend/domain/ports/project_db_handler.py @@ -23,3 +23,7 @@ def add_project(self, project: Project) -> bool: @abstractmethod def remove_project(self, name) -> bool: pass + + @abstractmethod + def update_batch_enabled(self, name: str, batch_enabled: bool) -> bool: + pass diff --git a/backend/domain/use_cases/batch_predict.py b/backend/domain/use_cases/batch_predict.py new file mode 100644 index 0000000..d216bd4 --- /dev/null +++ b/backend/domain/use_cases/batch_predict.py @@ -0,0 +1,118 @@ +# Philippe Stepniewski +from fastapi 
# Philippe Stepniewski
from fastapi import HTTPException
from loguru import logger

from backend.domain.entities.docker.utils import build_model_docker_image, check_docker_image_exists, sanitize_name
from backend.domain.ports.batch_prediction_handler import BatchPredictionHandler
from backend.domain.ports.object_storage_handler import ObjectStorageHandler
from backend.domain.ports.project_db_handler import ProjectDbHandler


def _job_prefix(model_name: str, version: str, job_id: str) -> str:
    """Project-relative storage prefix for one batch job: `<model>/<version>/<job_id>`.

    Centralized so submit / download / delete / cleanup can never drift apart
    on the path layout.
    """
    return f"{model_name}/{version}/{job_id}"


def _result_file_name(job_id: str) -> str:
    """Name of the predictions CSV produced by a batch job."""
    return f"predictions-{job_id}.csv"


def _delete_job_files(
    project_name: str, model_name: str, model_version: str, job_id: str, object_storage: ObjectStorageHandler
) -> None:
    """Remove every stored object belonging to one batch job."""
    prefix = f"{_job_prefix(model_name, model_version, job_id)}/"
    for stored_file in object_storage.list_files(project_name, prefix):
        object_storage.delete_file(project_name, stored_file)


def ensure_model_image_exists(registry, project_name: str, model_name: str, version: str):
    """Build the model's Docker image unless one with batch support already exists.

    Raises:
        HTTPException: 500 when the image build fails.
    """
    image_name = sanitize_name(f"{project_name}_{model_name}_{version}_ctr")
    if check_docker_image_exists(image_name):
        logger.info(f"Docker image '{image_name}' already exists, skipping build")
        return
    logger.info(f"Docker image '{image_name}' not found, building...")
    build_status = build_model_docker_image(registry, project_name, model_name, version)
    # NOTE(review): assumes build_model_docker_image returns a falsy value (0/False)
    # on failure — confirm against its implementation.
    if build_status == 0:
        raise HTTPException(status_code=500, detail="Failed to build model image")
    logger.info(f"Docker image '{image_name}' built successfully")


def submit_batch_prediction(
    project_name: str,
    model_name: str,
    version: str,
    file_content: bytes,
    job_id: str,
    object_storage: ObjectStorageHandler,
    batch_handler: BatchPredictionHandler,
    project_db_handler: ProjectDbHandler,
    registry=None,
):
    """Upload the input CSV, ensure the model image exists, and launch the batch job.

    Raises:
        HTTPException: 400 when batch predictions are disabled for the project.

    Returns:
        dict: JSON-serializable descriptor of the created batch prediction.
    """
    project = project_db_handler.get_project(project_name)
    if not project.batch_enabled:
        raise HTTPException(status_code=400, detail="Batch predictions are not enabled for this project")

    prefix = _job_prefix(model_name, version, job_id)
    # The compute job receives bucket-absolute (project-prefixed) paths, while
    # the object-storage port expects paths relative to the project space.
    input_path = f"{project_name}/{prefix}/input.csv"
    output_path = f"{project_name}/{prefix}/{_result_file_name(job_id)}"

    logger.info(f"Uploading input file to {input_path}")
    object_storage.upload_file(project_name, f"{prefix}/input.csv", file_content)

    # The registry is optional so tests / callers without Docker can still submit.
    if registry:
        ensure_model_image_exists(registry, project_name, model_name, version)

    batch_prediction = batch_handler.create_batch_job(
        project_name, model_name, version, input_path, output_path, job_id
    )
    return batch_prediction.to_json()


def get_batch_prediction_status(project_name: str, job_id: str, batch_handler: BatchPredictionHandler):
    """Return the JSON descriptor of a single batch job."""
    batch_prediction = batch_handler.get_job_status(project_name, job_id)
    return batch_prediction.to_json()


def list_batch_predictions(project_name: str, batch_handler: BatchPredictionHandler):
    """Return the JSON descriptors of all batch jobs of the project."""
    jobs = batch_handler.list_batch_jobs(project_name)
    return [job.to_json() for job in jobs]


def download_batch_result(
    project_name: str,
    job_id: str,
    batch_handler: BatchPredictionHandler,
    object_storage: ObjectStorageHandler,
):
    """Fetch the raw predictions CSV produced by a batch job.

    The model/version are recovered from the job itself so callers only need
    the job id.
    """
    batch_prediction = batch_handler.get_job_status(project_name, job_id)
    prefix = _job_prefix(batch_prediction.model_name, batch_prediction.model_version, job_id)
    output_remote_path = f"{prefix}/{_result_file_name(job_id)}"
    return object_storage.download_file(project_name, output_remote_path)


def delete_batch_prediction(
    project_name: str,
    job_id: str,
    batch_handler: BatchPredictionHandler,
    object_storage: ObjectStorageHandler,
):
    """Delete one batch job and all of its stored files."""
    batch_prediction = batch_handler.get_job_status(project_name, job_id)

    batch_handler.delete_batch_job(project_name, job_id)
    _delete_job_files(project_name, batch_prediction.model_name, batch_prediction.model_version, job_id, object_storage)

    logger.info(f"Deleted batch prediction {job_id} and associated files")
    return True


def cleanup_batch_predictions(
    project_name: str,
    batch_handler: BatchPredictionHandler,
    object_storage: ObjectStorageHandler,
):
    """Delete every finished (completed or failed) batch job and its files.

    Returns:
        int: number of jobs removed.
    """
    finished_jobs = batch_handler.list_finished_jobs(project_name)
    deleted_count = 0
    for job in finished_jobs:
        job_id = job.job_id
        batch_handler.delete_batch_job(project_name, job_id)
        _delete_job_files(project_name, job.model_name, job.model_version, job_id, object_storage)

        deleted_count += 1
        logger.info(f"Cleaned up batch job {job_id} ({job.status.value})")

    logger.info(f"Cleaned up {deleted_count} finished batch jobs for project {project_name}")
    return deleted_count
# Philippe Stepniewski
from datetime import datetime, timezone

from kubernetes import client
from kubernetes.client.rest import ApiException
from loguru import logger

from backend.domain.entities.batch_prediction import BatchPrediction, BatchPredictionStatus
from backend.domain.ports.batch_prediction_handler import BatchPredictionHandler
from backend.infrastructure.k8s_deployment import K8SDeployment
from backend.utils import sanitize_project_name


class K8sBatchPredictionAdapter(BatchPredictionHandler, K8SDeployment):
    """Runs batch predictions as Kubernetes Jobs in the project's namespace.

    Job metadata is the source of truth: model/version/job id are stored as
    labels and the storage paths as container env vars, so status and listing
    can be reconstructed purely from the Job objects.
    """

    def __init__(self):
        super().__init__()
        # BatchV1Api drives Jobs; pod-level calls reuse K8SDeployment's client.
        self.batch_api = client.BatchV1Api()

    def create_batch_job(
        self, project_name: str, model_name: str, model_version: str, input_path: str, output_path: str, job_id: str
    ) -> BatchPrediction:
        """Create a `batch-<job_id>` Job running the model image's batch script.

        Returns a PENDING descriptor; actual state is read later via labels/env.
        """
        namespace = sanitize_project_name(project_name)
        docker_image_name = sanitize_project_name(f"{project_name}_{model_name}_{model_version}_ctr")

        # Paths + credentials the in-container batch script needs to reach MinIO.
        env_vars = [
            client.V1EnvVar(name="INPUT_PATH", value=input_path),
            client.V1EnvVar(name="OUTPUT_PATH", value=output_path),
            client.V1EnvVar(name="BATCH_BUCKET", value="batch-predictions"),
            client.V1EnvVar(name="MLFLOW_S3_ENDPOINT_URL", value=self._get_env("MLFLOW_S3_ENDPOINT_URL")),
            client.V1EnvVar(name="AWS_ACCESS_KEY_ID", value=self._get_env("AWS_ACCESS_KEY_ID", "minio_user")),
            client.V1EnvVar(
                name="AWS_SECRET_ACCESS_KEY", value=self._get_env("AWS_SECRET_ACCESS_KEY", "minio_password")
            ),
        ]

        job = client.V1Job(
            metadata=client.V1ObjectMeta(
                name=f"batch-{job_id}",
                namespace=namespace,
                labels={
                    "app": "batch-prediction",
                    "project": sanitize_project_name(project_name),
                    "model": sanitize_project_name(model_name),
                    "version": sanitize_project_name(model_version),
                    # Raw names kept alongside sanitized ones so _job_to_batch_prediction
                    # can report the original identifiers.
                    # NOTE(review): k8s label values are restricted (alphanumerics,
                    # '-', '_', '.', max 63 chars) — confirm raw model names comply,
                    # otherwise the API will reject the Job.
                    "model-raw": model_name,
                    "version-raw": model_version,
                    "job_id": job_id,
                },
            ),
            spec=client.V1JobSpec(
                backoff_limit=1,  # one retry on failure
                ttl_seconds_after_finished=3600,  # k8s GC removes finished jobs after 1h
                template=client.V1PodTemplateSpec(
                    metadata=client.V1ObjectMeta(
                        labels={
                            "app": "batch-prediction",
                            "job_id": job_id,
                        }
                    ),
                    spec=client.V1PodSpec(
                        containers=[
                            client.V1Container(
                                name="batch-predict",
                                image=f"{docker_image_name}:latest",
                                image_pull_policy="IfNotPresent",
                                command=["bash", "-c", "uv run python batch_predict_template.py"],
                                env=env_vars,
                            )
                        ],
                        restart_policy="Never",
                    ),
                ),
            ),
        )

        self.batch_api.create_namespaced_job(namespace=namespace, body=job)
        logger.info(f"Created batch job batch-{job_id} in namespace {namespace}")

        return BatchPrediction(
            job_id=job_id,
            project_name=project_name,
            model_name=model_name,
            model_version=model_version,
            status=BatchPredictionStatus.PENDING,
            input_path=input_path,
            output_path=output_path,
            created_at=datetime.now(timezone.utc),
        )

    def get_job_status(self, project_name: str, job_id: str) -> BatchPrediction:
        """Read the Job from the cluster and map it to a BatchPrediction.

        Raises kubernetes ApiException (404) when the job does not exist.
        """
        namespace = sanitize_project_name(project_name)
        job = self.batch_api.read_namespaced_job(name=f"batch-{job_id}", namespace=namespace)
        return self._job_to_batch_prediction(job)

    def list_batch_jobs(self, project_name: str) -> list[BatchPrediction]:
        """List every batch-prediction Job in the project's namespace."""
        namespace = sanitize_project_name(project_name)
        jobs = self.batch_api.list_namespaced_job(namespace=namespace, label_selector="app=batch-prediction")
        return [self._job_to_batch_prediction(job) for job in jobs.items]

    def delete_batch_job(self, project_name: str, job_id: str) -> bool:
        """Delete the Job (and its pods, via Background propagation).

        Returns False when the job was already gone; re-raises other API errors.
        """
        namespace = sanitize_project_name(project_name)
        try:
            self.batch_api.delete_namespaced_job(
                name=f"batch-{job_id}",
                namespace=namespace,
                body=client.V1DeleteOptions(propagation_policy="Background"),
            )
            logger.info(f"Deleted batch job batch-{job_id} from namespace {namespace}")
            return True
        except ApiException as e:
            if e.status == 404:
                logger.warning(f"Batch job batch-{job_id} not found in namespace {namespace}")
                return False
            raise

    def list_finished_jobs(self, project_name: str) -> list[BatchPrediction]:
        """Return jobs whose mapped status is COMPLETED or FAILED."""
        namespace = sanitize_project_name(project_name)
        jobs = self.batch_api.list_namespaced_job(namespace=namespace, label_selector="app=batch-prediction")
        finished = []
        for job in jobs.items:
            bp = self._job_to_batch_prediction(job)
            if bp.status in (BatchPredictionStatus.COMPLETED, BatchPredictionStatus.FAILED):
                finished.append(bp)
        return finished

    def _job_to_batch_prediction(self, job: client.V1Job) -> BatchPrediction:
        """Rebuild a BatchPrediction from a Job's labels, env vars and status."""
        labels = job.metadata.labels or {}
        status = self._map_job_status(job.status)

        started_at = None
        completed_at = None
        error_message = None

        if job.status.start_time:
            started_at = job.status.start_time

        if job.status.completion_time:
            completed_at = job.status.completion_time

        # Only fetch pod logs for failed jobs — it costs extra API calls.
        if status == BatchPredictionStatus.FAILED:
            error_message = self._get_pod_error_logs(job.metadata.namespace, job.metadata.name)

        # Recover the storage paths from the container env (set in create_batch_job).
        env_vars = {}
        containers = job.spec.template.spec.containers
        if containers:
            for env in containers[0].env or []:
                env_vars[env.name] = env.value

        return BatchPrediction(
            job_id=labels.get("job_id", ""),
            project_name=labels.get("project", ""),
            # Prefer the raw labels; fall back to sanitized ones for older jobs.
            model_name=labels.get("model-raw", labels.get("model", "")),
            model_version=labels.get("version-raw", labels.get("version", "")),
            status=status,
            input_path=env_vars.get("INPUT_PATH", ""),
            output_path=env_vars.get("OUTPUT_PATH", ""),
            created_at=job.metadata.creation_timestamp or datetime.now(timezone.utc),
            started_at=started_at,
            completed_at=completed_at,
            error_message=error_message,
        )

    def _map_job_status(self, status: client.V1JobStatus) -> BatchPredictionStatus:
        """Map k8s Job counters to our status enum (succeeded > failed > active > pending)."""
        if status.succeeded and status.succeeded > 0:
            return BatchPredictionStatus.COMPLETED
        if status.failed and status.failed > 0:
            return BatchPredictionStatus.FAILED
        if status.active and status.active > 0:
            return BatchPredictionStatus.RUNNING
        return BatchPredictionStatus.PENDING

    def _get_pod_error_logs(self, namespace: str, job_name: str) -> str:
        """Best-effort: last non-empty log lines of the job's first pod.

        Never raises — returns a human-readable placeholder on any failure.
        """
        try:
            pods = self.service_api_instance.list_namespaced_pod(
                namespace=namespace, label_selector=f"job-name={job_name}"
            )
            if not pods.items:
                return "No pod found for this job"
            pod = pods.items[0]
            logs = self.service_api_instance.read_namespaced_pod_log(
                name=pod.metadata.name, namespace=namespace, tail_lines=30
            )
            lines = [line for line in logs.strip().splitlines() if line.strip()]
            if not lines:
                return "No logs available"
            return "\n".join(lines[-5:])
        except Exception as e:
            logger.warning(f"Could not fetch pod logs for job {job_name}: {e}")
            return "Could not retrieve error details"

    def _get_env(self, key: str, default: str = "") -> str:
        """Read an env var with a default (local import kept as in the original)."""
        import os

        return os.environ.get(key, default)
BATCH_BUCKET = "batch-predictions"


class MinioStorageAdapter(ObjectStorageHandler):
    """S3/MinIO-backed object storage: one shared bucket, one key prefix per project."""

    def __init__(self):
        # Credentials default to the local MinIO dev values when not configured.
        self.s3 = boto3.client(
            "s3",
            endpoint_url=os.environ["MLFLOW_S3_ENDPOINT_URL"],
            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID", "minio_user"),
            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY", "minio_password"),
        )

    @staticmethod
    def _key(project_name: str, remote_path: str) -> str:
        """Full object key for a project-relative path."""
        return f"{project_name}/{remote_path}"

    def _ensure_bucket(self):
        """Create the shared bucket when it does not exist yet (idempotent)."""
        try:
            self.s3.head_bucket(Bucket=BATCH_BUCKET)
        except ClientError:
            self.s3.create_bucket(Bucket=BATCH_BUCKET)
            logger.info(f"Created bucket '{BATCH_BUCKET}'")

    def ensure_project_space(self, project_name: str) -> None:
        """Materialize the project's prefix with an empty `.keep` marker object."""
        self._ensure_bucket()
        try:
            self.s3.put_object(Bucket=BATCH_BUCKET, Key=f"{project_name}/.keep", Body=b"")
        except ClientError as e:
            # Best effort: a missing marker only affects prefix visibility.
            logger.warning(f"Could not create marker for '{project_name}': {e}")
        logger.info(f"Ensured project space for '{project_name}' in bucket '{BATCH_BUCKET}'")

    def remove_project_space(self, project_name: str) -> None:
        """Delete every object stored under the project's prefix, page by page."""
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=BATCH_BUCKET, Prefix=f"{project_name}/"):
            to_delete = [{"Key": entry["Key"]} for entry in page.get("Contents", [])]
            if to_delete:
                self.s3.delete_objects(Bucket=BATCH_BUCKET, Delete={"Objects": to_delete})
        logger.info(f"Removed project space for '{project_name}' from bucket '{BATCH_BUCKET}'")

    def upload_file(self, project_name: str, remote_path: str, file_content: bytes) -> None:
        """Store `file_content` at the project-relative path (overwrites existing)."""
        self._ensure_bucket()
        self.s3.put_object(Bucket=BATCH_BUCKET, Key=self._key(project_name, remote_path), Body=file_content)

    def download_file(self, project_name: str, remote_path: str) -> bytes:
        """Return the raw bytes stored at the project-relative path."""
        obj = self.s3.get_object(Bucket=BATCH_BUCKET, Key=self._key(project_name, remote_path))
        return obj["Body"].read()

    def list_files(self, project_name: str, prefix: str = "") -> list[str]:
        """List stored files under `prefix`, returned as project-relative paths."""
        base = f"{project_name}/"
        pages = self.s3.get_paginator("list_objects_v2").paginate(Bucket=BATCH_BUCKET, Prefix=f"{base}{prefix}")
        return [entry["Key"].removeprefix(base) for page in pages for entry in page.get("Contents", [])]

    def delete_file(self, project_name: str, remote_path: str) -> None:
        """Remove one stored object."""
        self.s3.delete_object(Bucket=BATCH_BUCKET, Key=self._key(project_name, remote_path))

    def file_exists(self, project_name: str, remote_path: str) -> bool:
        """True when an object exists at the project-relative path."""
        try:
            self.s3.head_object(Bucket=BATCH_BUCKET, Key=self._key(project_name, remote_path))
        except ClientError:
            return False
        return True
def map_rows_to_projects(rows: list) -> list[Project]:
    """Convert raw DB result rows into Project entities.

    Rows from databases created before the batch feature have no 6th column;
    those default to batch_enabled=False.
    """
    projects = []
    for row in rows:
        has_batch_column = len(row) > 5
        projects.append(
            Project(
                name=row[1],
                owner=row[2],
                scope=row[3],
                data_perimeter=row[4],
                batch_enabled=bool(row[5]) if has_batch_column else False,
            )
        )
    return projects
# Philippe Stepniewski
import os

import typer

from cli.utils.api_calls import get_and_print
from cli.utils.token import get_client


def submit_batch(
    project_name: str,
    model_name: str,
    version: str,
    file_path: str = typer.Option(..., help="Path to the CSV file to process"),
):
    """Submit a batch prediction job"""
    client = get_client()
    with open(file_path, "rb") as f:
        r = client.post(
            f"/{project_name}/batch/submit/{model_name}/{version}",
            # os.path.basename is portable; splitting on "/" breaks on Windows paths.
            files={"file": (os.path.basename(file_path), f, "text/csv")},
        )
    if r.status_code == 200:
        result = r.json()
        print(f"Batch job submitted successfully. Job ID: {result.get('job_id', 'unknown')}")
    else:
        print(f"Error submitting batch job: {r.text}")


def batch_status(project_name: str, job_id: str):
    """Get the status of a batch prediction job"""
    get_and_print(
        f"/{project_name}/batch/status/{job_id}",
        error_message="Error fetching batch job status",
        success_message="Batch job status retrieved",
    )


def list_batch_jobs(project_name: str):
    """List all batch prediction jobs for a project"""
    get_and_print(
        f"/{project_name}/batch/list",
        error_message="Error listing batch jobs",
        success_message="No batch jobs found",
    )


def download_batch_result(
    project_name: str,
    job_id: str,
    output: str = typer.Option("predictions.csv", help="Output file path"),
):
    """Download the result of a batch prediction job"""
    client = get_client()
    r = client.get(f"/{project_name}/batch/download/{job_id}")
    if r.status_code == 200:
        with open(output, "wb") as f:
            f.write(r.content)
        print(f"Results downloaded to {output}")
    else:
        print(f"Error downloading batch result: {r.text}")


def delete_batch_job(project_name: str, job_id: str):
    """Delete a batch prediction job and its files"""
    client = get_client()
    r = client.delete(f"/{project_name}/batch/{job_id}")
    if r.status_code == 200:
        print("Batch job deleted successfully")
    else:
        print(f"Error deleting batch job: {r.text}")
def delete_project(name: str):
    """Delete a project by name"""
    get_and_print(
        f"/projects/{name}/remove",
        error_message="Error deleting project",
        success_message="Project deleted successfully",
    )


def _set_batch_flag(project_name: str, enabled: bool, verb: str, past: str) -> None:
    # Shared PATCH call behind the enable/disable batch commands.
    patch_and_print(
        f"/projects/{project_name}/batch_enabled",
        {"batch_enabled": enabled},
        error_message=f"Error {verb} batch predictions",
        success_message=f"Batch predictions {past} successfully",
    )


def enable_batch(project_name: str):
    """Enable batch predictions for a project"""
    _set_batch_flag(project_name, True, "enabling", "enabled")


def disable_batch(project_name: str):
    """Disable batch predictions for a project"""
    _set_batch_flag(project_name, False, "disabling", "disabled")
print("✅ User added to project successfully") + print("User added to project successfully") else: print(r.content) - print("❌ Error adding user to project") + print("Error adding user to project") diff --git a/cli/main.py b/cli/main.py index f042c41..ba1706d 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,18 +1,29 @@ import typer from cli.commands.auth import login, me +from cli.commands.batch import batch_status, delete_batch_job, download_batch_result, list_batch_jobs, submit_batch from cli.commands.demo import get_status, list_simulations, start_simulation, stop_simulation from cli.commands.models import deploy_model, list_deployed_models, list_models, search_model_infos, undeploy_model -from cli.commands.projects import add_project, add_user_to_project, delete_project, list_projects, project_info +from cli.commands.projects import ( + add_project, + add_user_to_project, + delete_project, + disable_batch, + enable_batch, + list_projects, + project_info, +) from cli.commands.users import add_user, list_users app = typer.Typer() project_app = typer.Typer() user_app = typer.Typer() demo_app = typer.Typer() +batch_app = typer.Typer() app.add_typer(project_app, name="projects") app.add_typer(user_app, name="users") app.add_typer(demo_app, name="demo") +app.add_typer(batch_app, name="batch") app.command()(login) app.command()(me) @@ -27,11 +38,18 @@ project_app.command("undeploy")(undeploy_model) project_app.command("list-deployed-models")(list_deployed_models) project_app.command("delete")(delete_project) +project_app.command("enable-batch")(enable_batch) +project_app.command("disable-batch")(disable_batch) user_app.command("list")(list_users) user_app.command("add")(add_user) demo_app.command("list")(list_simulations) demo_app.command("start")(start_simulation) demo_app.command("stop")(stop_simulation) demo_app.command("status")(get_status) +batch_app.command("submit")(submit_batch) +batch_app.command("status")(batch_status) 
def patch_and_print(
    endpoint: str, payload: dict, error_message: str = "❌ Error", success_message: str = "✅ Success"
) -> None:
    """Send a PATCH with a JSON payload and print the outcome.

    On HTTP 200 prints `success_message`; otherwise prints the raw response
    body followed by `error_message`. Fixed default message: "✅Success" was
    missing the space after the emoji, unlike the sibling helpers' messages.
    """
    client = get_client()
    r = client.patch(endpoint, json=payload)
    if r.status_code == 200:
        print(success_message)
    else:
        print(r.content)
        print(error_message)
index 23913bc..3d86fd5 100644 --- a/frontend/js/api.js +++ b/frontend/js/api.js @@ -82,6 +82,8 @@ const API = (() => { headers: { Authorization: `Bearer ${Auth.getToken()}` }, }).then(r => r.blob()), + updateBatchEnabled: (name, enabled) => + request('PATCH', `/projects/${name}/batch_enabled`, { body: { batch_enabled: enabled } }), registryStatus: (name) => get(`/projects/${name}/registry_status`).then(r => r.status).catch(() => 'error'), // User management @@ -223,6 +225,34 @@ const API = (() => { }), }, + // ── Batch Predictions ────────────────────────────────────── + batch: { + submit(proj, modelName, version, file) { + const fd = new FormData(); + fd.append('file', file); + return fetch(`${API_BASE}/${enc(proj)}/batch/submit/${enc(modelName)}/${enc(version)}`, { + method: 'POST', + headers: { Authorization: `Bearer ${Auth.getToken()}` }, + body: fd, + }).then(async r => { + if (!r.ok) throw new Error((await r.json()).detail || r.statusText); + return r.json(); + }); + }, + status: (proj, jobId) => get(`/${enc(proj)}/batch/status/${enc(jobId)}`), + list: (proj) => get(`/${enc(proj)}/batch/list`), + download(proj, jobId) { + return fetch(`${API_BASE}/${enc(proj)}/batch/download/${enc(jobId)}`, { + headers: { Authorization: `Bearer ${Auth.getToken()}` }, + }).then(async r => { + if (!r.ok) throw new Error((await r.json()).detail || r.statusText); + return r.blob(); + }); + }, + delete: (proj, jobId) => request('DELETE', `/${enc(proj)}/batch/${enc(jobId)}`), + cleanup: (proj) => request('POST', `/${enc(proj)}/batch/cleanup`), + }, + // ── Health ──────────────────────────────────────────────── health: { check: () => get('/health'), diff --git a/frontend/js/pages/project-detail.js b/frontend/js/pages/project-detail.js index 49d60c3..5a3ea19 100644 --- a/frontend/js/pages/project-detail.js +++ b/frontend/js/pages/project-detail.js @@ -39,12 +39,17 @@ const ProjectDetailPage = (() => { Deployed Models +
+
`; @@ -84,6 +89,7 @@ const ProjectDetailPage = (() => { case 'models': loadModels(projectName, panel); break; case 'registry': loadRegistry(projectName, panel); break; case 'deployed': loadDeployed(projectName, panel); break; + case 'batch': loadBatch(projectName, panel); break; } } @@ -107,6 +113,7 @@ const ProjectDetailPage = (() => { const owner = info.owner || info.Owner || '—'; const scope = info.scope || info.Scope || '—'; const perimeter = info.data_perimeter || info['Data Perimeter'] || '—'; + const batchEnabled = info.batch_enabled || false; panel.innerHTML = `
@@ -128,6 +135,19 @@ const ProjectDetailPage = (() => {
+
+
+ Batch Predictions +
+
+ + ${batchEnabled ? 'Enabled' : 'Disabled'} +
+
+
Users Access @@ -145,6 +165,34 @@ const ProjectDetailPage = (() => { renderUsersTable(projectName, users, document.getElementById('users-table-area')); + document.getElementById('batch-toggle').addEventListener('change', async (e) => { + const enabled = e.target.checked; + const statusEl = document.getElementById('batch-status'); + + if (!enabled) { + const ok = await Modal.confirm({ + title: 'Disable Batch Predictions', + message: `This will permanently delete all batch prediction files stored for project ${escHtml(projectName)}. This action cannot be undone.`, + confirmLabel: 'Disable & Delete', + danger: true, + }); + if (!ok) { + e.target.checked = true; + return; + } + } + + try { + await API.projects.updateBatchEnabled(projectName, enabled); + statusEl.className = `badge ${enabled ? 'badge-running' : 'badge-neutral'}`; + statusEl.textContent = enabled ? 'Enabled' : 'Disabled'; + Toast.success(`Batch predictions ${enabled ? 'enabled' : 'disabled'}.`); + } catch (err) { + e.target.checked = !enabled; + Toast.error(err.message); + } + }); + document.getElementById('add-user-btn').addEventListener('click', () => toggleAddUserForm(projectName) ); @@ -651,6 +699,282 @@ const ProjectDetailPage = (() => { }); } + // ── Batch Predictions Tab ──────────────────────────────────── + + const batchPollingTimers = {}; + + async function loadBatch(projectName, panel) { + panel.innerHTML = loadingHTML(); + try { + const info = await API.projects.info(projectName); + const batchEnabled = info.batch_enabled || false; + + if (!batchEnabled) { + panel.innerHTML = emptyHTML( + 'Batch Predictions Disabled', + 'Enable batch predictions in the Settings tab to use this feature.' 
+ ); + return; + } + + const [models, jobs] = await Promise.all([ + API.models.list(projectName).catch(() => []), + API.batch.list(projectName).catch(() => []), + ]); + + // Fetch all versions for each model + const modelsWithVersions = await Promise.all( + (models || []).map(async m => { + try { + const versions = await API.models.versions(projectName, m.name); + return { ...m, all_versions: versions }; + } catch { + return { ...m, all_versions: m.latest_versions || [] }; + } + }) + ); + renderBatch(projectName, modelsWithVersions, jobs, panel); + } catch (err) { + panel.innerHTML = errorHTML(err.message); + } + } + + function renderBatch(projectName, models, jobs, panel) { + const modelOptions = (models || []).flatMap(m => { + const name = m.name || ''; + const versionObjects = m.all_versions || m.latest_versions || []; + const versionNumbers = versionObjects + .map(v => (typeof v === 'object' ? v.version : v)) + .filter(Boolean) + .sort((a, b) => Number(b) - Number(a)); + return versionNumbers.map(v => + `` + ); + }).join(''); + + const noModelsMsg = (!models || models.length === 0) + ? '

No models available. Register models in MLflow to submit batch predictions.

' + : ''; + + const rows = (jobs || []).map(j => batchJobRow(j, projectName)).join(''); + + panel.innerHTML = ` +
+
+ Submit Batch Prediction +
+
+ ${noModelsMsg} + ${models && models.length > 0 ? ` +
+ + + No file selected + +
` : ''} +
+
+ +
+
+ Batch Jobs + ${(jobs || []).length} + + +
+
+ ${rows ? ` +
+ + + ${rows} +
Job IDModelVersionStatusCreatedActions
+
` : emptyHTML('No batch jobs', 'Submit a batch prediction to get started.')} +
+
+ `; + + // File input handling + const fileInput = document.getElementById('batch-file-input'); + const submitBtn = document.getElementById('batch-submit-btn'); + if (fileInput) { + fileInput.addEventListener('change', () => { + const fileName = fileInput.files[0]?.name || 'No file selected'; + document.getElementById('batch-file-name').textContent = fileName; + submitBtn.disabled = !fileInput.files[0]; + }); + } + + // Submit handling + if (submitBtn) { + submitBtn.addEventListener('click', async () => { + const select = document.getElementById('batch-model-select'); + const [modelName, version] = select.value.split(':'); + const file = fileInput.files[0]; + if (!file) return; + + submitBtn.disabled = true; + submitBtn.innerHTML = ''; + try { + const result = await API.batch.submit(projectName, modelName, version, file); + Toast.success(`Batch job submitted: ${result.job_id}`); + loadBatch(projectName, panel); + } catch (err) { + Toast.error(`Submit failed: ${err.message}`); + submitBtn.innerHTML = 'Submit'; + submitBtn.disabled = false; + } + }); + } + + // Refresh + document.getElementById('batch-refresh-btn')?.addEventListener('click', () => + loadBatch(projectName, panel) + ); + + // Cleanup finished/failed jobs + document.getElementById('batch-cleanup-btn')?.addEventListener('click', async () => { + const ok = await Modal.confirm({ + title: 'Cleanup Batch Jobs', + message: 'Delete all completed and failed batch jobs and their associated files?', + confirmLabel: 'Cleanup', + danger: true, + }); + if (ok) { + try { + const result = await API.batch.cleanup(projectName); + Toast.success(`Cleaned up ${result.deleted} batch job(s).`); + loadBatch(projectName, panel); + } catch (err) { Toast.error(err.message); } + } + }); + + // Delete buttons + panel.querySelectorAll('.batch-delete-btn').forEach(btn => { + btn.addEventListener('click', async () => { + const jobId = btn.dataset.jobId; + const ok = await Modal.confirm({ + title: 'Delete Batch Job', + message: 
`Delete batch job ${escHtml(jobId)} and all associated files?`, + confirmLabel: 'Delete', + danger: true, + }); + if (ok) { + try { + await API.batch.delete(projectName, jobId); + Toast.success('Batch job deleted.'); + loadBatch(projectName, panel); + } catch (err) { Toast.error(err.message); } + } + }); + }); + + // Error detail buttons + panel.querySelectorAll('.batch-error-btn').forEach(btn => { + btn.addEventListener('click', () => { + Modal.confirm({ + title: 'Batch Job Error', + message: `
${escHtml(btn.dataset.error)}
`, + confirmLabel: 'OK', + }); + }); + }); + + // Download buttons + panel.querySelectorAll('.batch-download-btn').forEach(btn => { + btn.addEventListener('click', async () => { + const jobId = btn.dataset.jobId; + try { + const blob = await API.batch.download(projectName, jobId); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `predictions-${jobId}.csv`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } catch (err) { Toast.error(err.message); } + }); + }); + + // Poll for building/pending/running jobs + (jobs || []).forEach(j => { + const s = (j.status || '').toLowerCase(); + if (s === 'building' || s === 'pending' || s === 'running') { + pollBatchStatus(projectName, j.job_id, panel); + } + }); + } + + function batchJobRow(j, projectName) { + const jobId = j.job_id || '—'; + const model = j.model_name || '—'; + const version = j.model_version || '—'; + const status = j.status || 'unknown'; + const created = j.created_at ? new Date(j.created_at).toLocaleString() : '—'; + const isCompleted = status.toLowerCase() === 'completed'; + const isFailed = status.toLowerCase() === 'failed'; + const errorMsg = j.error_message || ''; + + return ` + + ${escHtml(jobId)} + ${escHtml(model)} + ${escHtml(String(version))} + + ${statusBadge(status)} + ${isFailed && errorMsg ? `` : ''} + + ${escHtml(created)} + +
+ ${isCompleted ? `` : ''} + +
+ + `; + } + + function pollBatchStatus(projectName, jobId, panel) { + if (batchPollingTimers[jobId]) return; + const timer = setInterval(async () => { + try { + const result = await API.batch.status(projectName, jobId); + const state = (result.status || '').toLowerCase(); + if (state === 'completed' || state === 'failed') { + clearInterval(timer); + delete batchPollingTimers[jobId]; + loadBatch(projectName, panel); + if (state === 'completed') { + Toast.success(`Batch job ${jobId} completed.`); + } else { + Toast.error(`Batch job ${jobId} failed.`); + } + } + } catch { + clearInterval(timer); + delete batchPollingTimers[jobId]; + } + }, 3000); + batchPollingTimers[jobId] = timer; + } + // ── Helpers ────────────────────────────────────────────────── function statusBadge(status) { @@ -683,6 +1007,7 @@ const ProjectDetailPage = (() => { function clearPolling() { Object.values(pollingTimers).forEach(clearInterval); + Object.values(batchPollingTimers).forEach(clearInterval); } return { render }; diff --git a/frontend/js/pages/projects.js b/frontend/js/pages/projects.js index 5e34f58..0e93fd3 100644 --- a/frontend/js/pages/projects.js +++ b/frontend/js/pages/projects.js @@ -262,6 +262,10 @@ const ProjectsPage = (() => {
+
+ + +
`, footer: ` @@ -302,7 +306,8 @@ const ProjectsPage = (() => { btn.innerHTML = ' Creating…'; try { - await API.projects.add({ name, owner, scope, data_perimeter: perimeter }); + const batchEnabled = document.getElementById('new-proj-batch').checked; + await API.projects.add({ name, owner, scope, data_perimeter: perimeter, batch_enabled: batchEnabled }); close(); Toast.success(`Project "${name}" created.`); loadProjects(); diff --git a/infrastructure/k8s/backend-deployment.yaml b/infrastructure/k8s/backend-deployment.yaml index d48bd73..3c9ec65 100644 --- a/infrastructure/k8s/backend-deployment.yaml +++ b/infrastructure/k8s/backend-deployment.yaml @@ -72,6 +72,9 @@ rules: - apiGroups: [""] resources: ["pods"] verbs: [ "get", "list", "delete" ] +- apiGroups: [""] + resources: ["pods/log"] + verbs: [ "get" ] - apiGroups: [""] resources: ["services"] verbs: [ "get", "list", "create", "update", "patch", "delete" ] diff --git a/infrastructure/k8s/minio-deployment.yaml b/infrastructure/k8s/minio-deployment.yaml index daf0f8e..844c8f4 100644 --- a/infrastructure/k8s/minio-deployment.yaml +++ b/infrastructure/k8s/minio-deployment.yaml @@ -32,6 +32,8 @@ spec: args: - server - /data + - --console-address + - ":9001" env: - name: MINIO_ROOT_USER value: "minio_user" @@ -39,6 +41,7 @@ spec: value: "minio_password" ports: - containerPort: 9000 + - containerPort: 9001 volumeMounts: - name: minio-data mountPath: /data @@ -59,6 +62,9 @@ spec: - name: api port: 9000 targetPort: 9000 + - name: console + port: 9001 + targetPort: 9001 --- apiVersion: batch/v1 kind: Job diff --git a/pyproject.toml b/pyproject.toml index a8c29d5..dee2125 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dev = [ "pytest-asyncio>=0.25.3", "pytest-mock>=3.14.0", "pytest-playwright", + "pytest-order>=1.2.0", ] notebooks = [ "pyomo>=6.9.0", diff --git a/tests/conftest.py b/tests/conftest.py index 5c32f24..744e65d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,13 +55,28 @@ def 
cleanup_project(project_name: str) -> None: # Try CLI deletion first (this cleans DB + namespace) run_cli("projects", "delete", project_name) - # Force delete namespace if it still exists + # Delete namespace and wait for it to be fully removed subprocess.run( - ["kubectl", "delete", "namespace", project_name, "--ignore-not-found", "--wait=false"], + ["kubectl", "delete", "namespace", project_name, "--ignore-not-found"], capture_output=True, text=True, + timeout=120, ) + # Poll until namespace is actually gone (handles edge cases) + deadline = time.time() + 120 + while time.time() < deadline: + result = subprocess.run( + ["kubectl", "get", "namespace", project_name], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + return + time.sleep(5) + print(f"[WARNING] Namespace {project_name} still exists after 120s cleanup timeout") + def cleanup_test_namespaces(project_names: list[str] | None = None) -> None: """Cleanup all test namespaces.""" diff --git a/tests/test_unitaires/adapters/test_k8s_batch_prediction_adapter.py b/tests/test_unitaires/adapters/test_k8s_batch_prediction_adapter.py new file mode 100644 index 0000000..3cc7aa9 --- /dev/null +++ b/tests/test_unitaires/adapters/test_k8s_batch_prediction_adapter.py @@ -0,0 +1,129 @@ +# Philippe Stepniewski +import os +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +os.environ.setdefault("PATH_LOG_EVENTS", "/tmp/test_log_events") +os.environ.setdefault("MP_HOST_NAME", "localhost") +os.environ.setdefault("MP_DEPLOYMENT_PATH", "/deploy") +os.environ.setdefault("MP_DEPLOYMENT_PORT", "8000") +os.environ.setdefault("MLFLOW_S3_ENDPOINT_URL", "http://localhost:9000") +os.environ.setdefault("AWS_ACCESS_KEY_ID", "minio_user") +os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "minio_password") + + +from backend.domain.entities.batch_prediction import BatchPredictionStatus + + +@pytest.fixture +def mock_k8s(): + with ( + 
patch("backend.infrastructure.k8s_deployment.config") as mock_config, + patch("backend.infrastructure.k8s_deployment.client") as mock_base_client, + patch("backend.infrastructure.k8s_batch_prediction_adapter.client") as mock_client, + ): + mock_config.load_kube_config.return_value = None + mock_base_client.CoreV1Api.return_value = MagicMock() + mock_base_client.AppsV1Api.return_value = MagicMock() + mock_base_client.NetworkingV1Api.return_value = MagicMock() + mock_batch_api = MagicMock() + mock_client.BatchV1Api.return_value = mock_batch_api + mock_client.V1Job = MagicMock() + mock_client.V1ObjectMeta = MagicMock() + mock_client.V1JobSpec = MagicMock() + mock_client.V1PodTemplateSpec = MagicMock() + mock_client.V1PodSpec = MagicMock() + mock_client.V1Container = MagicMock() + mock_client.V1EnvVar = MagicMock() + mock_client.V1DeleteOptions = MagicMock() + + yield mock_batch_api, mock_client + + +def test_create_batch_job(mock_k8s): + mock_batch_api, mock_client = mock_k8s + + from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter + + adapter = K8sBatchPredictionAdapter() + + result = adapter.create_batch_job( + project_name="test-project", + model_name="my-model", + model_version="1", + input_path="test-project/my-model/1/abc/input.csv", + output_path="test-project/my-model/1/abc/predictions-abc.csv", + job_id="abc12345", + ) + + mock_batch_api.create_namespaced_job.assert_called_once() + assert result.project_name == "test-project" + assert result.model_name == "my-model" + assert result.model_version == "1" + assert result.status == BatchPredictionStatus.PENDING + + +def test_map_job_status_succeeded(mock_k8s): + from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter + + adapter = K8sBatchPredictionAdapter() + + status = MagicMock() + status.succeeded = 1 + status.failed = None + status.active = None + assert adapter._map_job_status(status) == BatchPredictionStatus.COMPLETED + + +def 
test_map_job_status_failed(mock_k8s): + from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter + + adapter = K8sBatchPredictionAdapter() + + status = MagicMock() + status.succeeded = None + status.failed = 1 + status.active = None + assert adapter._map_job_status(status) == BatchPredictionStatus.FAILED + + +def test_map_job_status_active(mock_k8s): + from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter + + adapter = K8sBatchPredictionAdapter() + + status = MagicMock() + status.succeeded = None + status.failed = None + status.active = 1 + assert adapter._map_job_status(status) == BatchPredictionStatus.RUNNING + + +def test_map_job_status_pending(mock_k8s): + from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter + + adapter = K8sBatchPredictionAdapter() + + status = MagicMock() + status.succeeded = None + status.failed = None + status.active = None + assert adapter._map_job_status(status) == BatchPredictionStatus.PENDING + + +def test_delete_batch_job(mock_k8s): + mock_batch_api, mock_client = mock_k8s + + from backend.infrastructure.k8s_batch_prediction_adapter import K8sBatchPredictionAdapter + + adapter = K8sBatchPredictionAdapter() + + result = adapter.delete_batch_job("test-project", "my-job-id") + + mock_batch_api.delete_namespaced_job.assert_called_once() + call_kwargs = mock_batch_api.delete_namespaced_job.call_args + assert call_kwargs[1]["name"] == "batch-my-job-id" + assert call_kwargs[1]["namespace"] == "test-project" + assert result is True diff --git a/tests/test_unitaires/adapters/test_minio_storage_adapter.py b/tests/test_unitaires/adapters/test_minio_storage_adapter.py new file mode 100644 index 0000000..4f11ef5 --- /dev/null +++ b/tests/test_unitaires/adapters/test_minio_storage_adapter.py @@ -0,0 +1,61 @@ +# Philippe Stepniewski +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from 
backend.infrastructure.minio_storage_adapter import BATCH_BUCKET, MinioStorageAdapter + + +@pytest.fixture +def adapter(): + with patch("backend.infrastructure.minio_storage_adapter.boto3") as mock_boto3: + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + with patch.dict("os.environ", {"MLFLOW_S3_ENDPOINT_URL": "http://minio:9000"}): + a = MinioStorageAdapter() + yield a, mock_client + + +def test_ensure_project_space_creates_bucket(adapter): + a, mock_client = adapter + mock_client.head_bucket.side_effect = ClientError({"Error": {"Code": "404"}}, "HeadBucket") + + a.ensure_project_space("my-project") + + mock_client.create_bucket.assert_called_once_with(Bucket=BATCH_BUCKET) + + +def test_ensure_project_space_existing_bucket(adapter): + a, mock_client = adapter + + a.ensure_project_space("my-project") + + mock_client.create_bucket.assert_not_called() + + +def test_remove_project_space_deletes_all_objects(adapter): + a, mock_client = adapter + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [ + {"Contents": [{"Key": "my-project/.keep"}, {"Key": "my-project/data.csv"}]}, + ] + + a.remove_project_space("my-project") + + mock_client.delete_objects.assert_called_once_with( + Bucket=BATCH_BUCKET, + Delete={"Objects": [{"Key": "my-project/.keep"}, {"Key": "my-project/data.csv"}]}, + ) + + +def test_remove_project_space_no_objects(adapter): + a, mock_client = adapter + paginator = MagicMock() + mock_client.get_paginator.return_value = paginator + paginator.paginate.return_value = [{"Contents": []}] + + a.remove_project_space("my-project") + + mock_client.delete_objects.assert_not_called() diff --git a/tests/test_unitaires/use_cases/test_batch_predict.py b/tests/test_unitaires/use_cases/test_batch_predict.py new file mode 100644 index 0000000..cee9374 --- /dev/null +++ b/tests/test_unitaires/use_cases/test_batch_predict.py @@ -0,0 +1,236 @@ +# Philippe Stepniewski +import os +from 
unittest.mock import MagicMock, patch + +import pytest + +os.environ.setdefault("PATH_LOG_EVENTS", "/tmp/test_log_events") + +from backend.domain.entities.batch_prediction import BatchPrediction, BatchPredictionStatus +from backend.domain.entities.project import Project +from backend.domain.use_cases.batch_predict import ( + delete_batch_prediction, + download_batch_result, + get_batch_prediction_status, + list_batch_predictions, + submit_batch_prediction, +) + + +@pytest.fixture +def mock_batch_handler(): + return MagicMock() + + +@pytest.fixture +def mock_object_storage(): + return MagicMock() + + +@pytest.fixture +def mock_project_db_handler(): + handler = MagicMock() + handler.get_project.return_value = Project( + name="test-project", owner="owner", scope="scope", data_perimeter="perimeter", batch_enabled=True + ) + return handler + + +@pytest.fixture +def mock_registry(): + return MagicMock() + + +@pytest.fixture +def sample_batch_prediction(): + from datetime import datetime, timezone + + return BatchPrediction( + job_id="abc12345", + project_name="test-project", + model_name="my-model", + model_version="1", + status=BatchPredictionStatus.PENDING, + input_path="test-project/my-model/1/abc12345/input.csv", + output_path="test-project/my-model/1/abc12345/predictions-abc12345.csv", + created_at=datetime.now(timezone.utc), + ) + + +def test_submit_uploads_file_and_creates_job( + mock_batch_handler, mock_object_storage, mock_project_db_handler, sample_batch_prediction +): + mock_batch_handler.create_batch_job.return_value = sample_batch_prediction + + result = submit_batch_prediction( + project_name="test-project", + model_name="my-model", + version="1", + file_content=b"col1,col2\n1,2\n3,4", + job_id="abc12345", + object_storage=mock_object_storage, + batch_handler=mock_batch_handler, + project_db_handler=mock_project_db_handler, + ) + + mock_object_storage.upload_file.assert_called_once() + mock_batch_handler.create_batch_job.assert_called_once() + assert 
result["project_name"] == "test-project" + assert result["model_name"] == "my-model" + + +def test_submit_fails_if_batch_not_enabled(mock_batch_handler, mock_object_storage): + project_db = MagicMock() + project_db.get_project.return_value = Project( + name="test-project", owner="owner", scope="scope", data_perimeter="perimeter", batch_enabled=False + ) + + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + submit_batch_prediction( + project_name="test-project", + model_name="my-model", + version="1", + file_content=b"col1,col2\n1,2", + job_id="abc12345", + object_storage=mock_object_storage, + batch_handler=mock_batch_handler, + project_db_handler=project_db, + ) + assert exc_info.value.status_code == 400 + + +def test_get_status_delegates_to_handler(mock_batch_handler, sample_batch_prediction): + mock_batch_handler.get_job_status.return_value = sample_batch_prediction + + result = get_batch_prediction_status("test-project", "abc12345", mock_batch_handler) + + mock_batch_handler.get_job_status.assert_called_once_with("test-project", "abc12345") + assert result["job_id"] == "abc12345" + assert result["status"] == "pending" + + +def test_list_delegates_to_handler(mock_batch_handler, sample_batch_prediction): + mock_batch_handler.list_batch_jobs.return_value = [sample_batch_prediction] + + result = list_batch_predictions("test-project", mock_batch_handler) + + mock_batch_handler.list_batch_jobs.assert_called_once_with("test-project") + assert len(result) == 1 + assert result[0]["job_id"] == "abc12345" + + +def test_download_returns_file_content(mock_batch_handler, mock_object_storage, sample_batch_prediction): + mock_batch_handler.get_job_status.return_value = sample_batch_prediction + mock_object_storage.download_file.return_value = b"prediction\n0.95\n0.32" + + result = download_batch_result("test-project", "abc12345", mock_batch_handler, mock_object_storage) + + mock_object_storage.download_file.assert_called_once_with( + 
"test-project", "my-model/1/abc12345/predictions-abc12345.csv" + ) + assert result == b"prediction\n0.95\n0.32" + + +def test_delete_cleans_up_job_and_storage(mock_batch_handler, mock_object_storage, sample_batch_prediction): + mock_batch_handler.get_job_status.return_value = sample_batch_prediction + mock_batch_handler.delete_batch_job.return_value = True + mock_object_storage.list_files.return_value = [ + "my-model/1/abc12345/input.csv", + "my-model/1/abc12345/predictions-abc12345.csv", + ] + + result = delete_batch_prediction("test-project", "abc12345", mock_batch_handler, mock_object_storage) + + assert result is True + mock_batch_handler.delete_batch_job.assert_called_once_with("test-project", "abc12345") + assert mock_object_storage.delete_file.call_count == 2 + + +@patch("backend.domain.use_cases.batch_predict.build_model_docker_image") +@patch("backend.domain.use_cases.batch_predict.check_docker_image_exists") +def test_submit_builds_image_if_not_exists( + mock_check, + mock_build, + mock_batch_handler, + mock_object_storage, + mock_project_db_handler, + mock_registry, + sample_batch_prediction, +): + mock_check.return_value = False + mock_build.return_value = 1 + mock_batch_handler.create_batch_job.return_value = sample_batch_prediction + + submit_batch_prediction( + project_name="test-project", + model_name="my-model", + version="1", + file_content=b"col1,col2\n1,2", + job_id="abc12345", + object_storage=mock_object_storage, + batch_handler=mock_batch_handler, + project_db_handler=mock_project_db_handler, + registry=mock_registry, + ) + + mock_check.assert_called_once() + mock_build.assert_called_once_with(mock_registry, "test-project", "my-model", "1") + + +@patch("backend.domain.use_cases.batch_predict.build_model_docker_image") +@patch("backend.domain.use_cases.batch_predict.check_docker_image_exists") +def test_submit_skips_build_if_image_exists( + mock_check, + mock_build, + mock_batch_handler, + mock_object_storage, + mock_project_db_handler, + 
mock_registry, + sample_batch_prediction, +): + mock_check.return_value = True + mock_batch_handler.create_batch_job.return_value = sample_batch_prediction + + submit_batch_prediction( + project_name="test-project", + model_name="my-model", + version="1", + file_content=b"col1,col2\n1,2", + job_id="abc12345", + object_storage=mock_object_storage, + batch_handler=mock_batch_handler, + project_db_handler=mock_project_db_handler, + registry=mock_registry, + ) + + mock_check.assert_called_once() + mock_build.assert_not_called() + + +@patch("backend.domain.use_cases.batch_predict.build_model_docker_image") +@patch("backend.domain.use_cases.batch_predict.check_docker_image_exists") +def test_submit_fails_if_build_fails( + mock_check, mock_build, mock_batch_handler, mock_object_storage, mock_project_db_handler, mock_registry +): + mock_check.return_value = False + mock_build.return_value = 0 + + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + submit_batch_prediction( + project_name="test-project", + model_name="my-model", + version="1", + file_content=b"col1,col2\n1,2", + job_id="abc12345", + object_storage=mock_object_storage, + batch_handler=mock_batch_handler, + project_db_handler=mock_project_db_handler, + registry=mock_registry, + ) + + assert exc_info.value.status_code == 500 + assert "Failed to build model image" in exc_info.value.detail diff --git a/tests/test_unitaires/use_cases/test_projects_usecases.py b/tests/test_unitaires/use_cases/test_projects_usecases.py new file mode 100644 index 0000000..16c1e93 --- /dev/null +++ b/tests/test_unitaires/use_cases/test_projects_usecases.py @@ -0,0 +1,85 @@ +# Philippe Stepniewski +import os +from unittest.mock import MagicMock, patch + +import pytest + +# Set required env var before importing the module that needs it +os.environ.setdefault("PATH_LOG_EVENTS", "/tmp/test_log_events") + +from backend.domain.entities.project import Project +from backend.domain.use_cases.projects_usecases 
import add_project, remove_project, update_project_batch_enabled + + +@pytest.fixture +def mock_project_db_handler(): + handler = MagicMock() + handler.add_project.return_value = True + handler.remove_project.return_value = True + handler.update_batch_enabled.return_value = True + return handler + + +@pytest.fixture +def mock_object_storage(): + return MagicMock() + + +@patch("backend.domain.use_cases.projects_usecases.deploy_registry") +def test_add_project_with_batch_enabled_creates_storage_space( + mock_deploy_registry, mock_project_db_handler, mock_object_storage +): + project = Project(name="test-project", owner="owner", scope="scope", data_perimeter="perimeter", batch_enabled=True) + + add_project(mock_project_db_handler, project, mock_object_storage) + + mock_object_storage.ensure_project_space.assert_called_once_with("test-project") + mock_project_db_handler.add_project.assert_called_once_with(project) + + +@patch("backend.domain.use_cases.projects_usecases.deploy_registry") +def test_add_project_without_batch_does_not_create_storage_space( + mock_deploy_registry, mock_project_db_handler, mock_object_storage +): + project = Project( + name="test-project", owner="owner", scope="scope", data_perimeter="perimeter", batch_enabled=False + ) + + add_project(mock_project_db_handler, project, mock_object_storage) + + mock_object_storage.ensure_project_space.assert_not_called() + mock_project_db_handler.add_project.assert_called_once_with(project) + + +@patch("backend.domain.use_cases.projects_usecases._remove_project_namespace") +def test_remove_project_cleans_up_storage(mock_remove_ns, mock_project_db_handler, mock_object_storage): + remove_project(mock_project_db_handler, "test-project", mock_object_storage) + + mock_object_storage.remove_project_space.assert_called_once_with("test-project") + mock_project_db_handler.remove_project.assert_called_once_with("test-project") + + +@patch("backend.domain.use_cases.projects_usecases._remove_project_namespace") +def 
test_remove_project_continues_if_storage_cleanup_fails( + mock_remove_ns, mock_project_db_handler, mock_object_storage +): + mock_object_storage.remove_project_space.side_effect = Exception("Storage error") + + result = remove_project(mock_project_db_handler, "test-project", mock_object_storage) + + assert result is True + mock_project_db_handler.remove_project.assert_called_once_with("test-project") + + +def test_update_batch_enabled_to_true_creates_space(mock_project_db_handler, mock_object_storage): + update_project_batch_enabled(mock_project_db_handler, "test-project", True, mock_object_storage) + + mock_object_storage.ensure_project_space.assert_called_once_with("test-project") + mock_project_db_handler.update_batch_enabled.assert_called_once_with("test-project", True) + + +def test_update_batch_enabled_to_false_removes_space(mock_project_db_handler, mock_object_storage): + update_project_batch_enabled(mock_project_db_handler, "test-project", False, mock_object_storage) + + mock_object_storage.remove_project_space.assert_called_once_with("test-project") + mock_project_db_handler.update_batch_enabled.assert_called_once_with("test-project", False) diff --git a/tests/tests_end_to_end/test_batch_prediction_e2e.py b/tests/tests_end_to_end/test_batch_prediction_e2e.py new file mode 100644 index 0000000..3370cbc --- /dev/null +++ b/tests/tests_end_to_end/test_batch_prediction_e2e.py @@ -0,0 +1,281 @@ +# Philippe Stepniewski +""" +End-to-end test: Batch prediction workflow. + +This test covers: +1. Project creation with batch enabled +2. Model training and push to MLflow +3. Batch prediction submission (triggers auto-build of Docker image) +4. Status polling until completion +5. Result download and validation +6. 
Cleanup (delete job + files + project) +""" + +import os +import random +import string +import subprocess +import sys +import time + +import mlflow +import mlflow.sklearn +import pytest +import requests +from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split + +from tests.conftest import MP_HOSTNAME, cleanup_project, login, run_cli + +_this_module = sys.modules[__name__] + +PROJECT_SUFFIX = "".join(random.choices(string.ascii_lowercase, k=6)) +PROJECT_NAME = f"e2ebatch{PROJECT_SUFFIX}" +MODEL_NAME = "batch_test_model" +MODEL_VERSION = "1" + + +def _setup_minikube_docker_env(): + try: + result = subprocess.run( + ["minikube", "docker-env", "--shell", "bash"], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + env_lines = [line.strip() for line in result.stdout.split("\n") if line.startswith("export ")] + for line in env_lines: + var_assignment = line.replace("export ", "", 1) + if "=" in var_assignment: + key, value = var_assignment.split("=", 1) + value = value.strip("\"'") + os.environ[key] = value + except Exception as exc: + print(f"[DEBUG] Error setting up minikube docker env: {exc}") + + +def _run_debug_cmd(label, cmd): + print(f"[DEBUG] {label}: {' '.join(cmd)}") + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=40) + print(f"[DEBUG] {label} exit={result.returncode}") + if result.stdout: + print(f"[DEBUG] {label} stdout:\n{result.stdout}") + if result.stderr: + print(f"[DEBUG] {label} stderr:\n{result.stderr}") + except Exception as exc: # noqa: BLE001 + print(f"[DEBUG] {label} error: {exc}") + + +def _dump_batch_debug_info(): + _run_debug_cmd("kubectl get jobs", ["kubectl", "get", "jobs", "-n", PROJECT_NAME, "-o", "wide"]) + _run_debug_cmd("kubectl get pods", ["kubectl", "get", "pods", "-n", PROJECT_NAME, "-o", "wide"]) + _run_debug_cmd( + "kubectl get events", + ["kubectl", "get", "events", "-n", 
PROJECT_NAME, "--sort-by=.metadata.creationTimestamp"], + ) + result = subprocess.run( + ["kubectl", "get", "pods", "-n", PROJECT_NAME, "-l", "app=batch-prediction", "--no-headers"], + capture_output=True, + text=True, + timeout=20, + ) + if result.returncode == 0 and result.stdout.strip(): + for line in result.stdout.strip().splitlines(): + pod_name = line.split()[0] + _run_debug_cmd(f"kubectl logs {pod_name}", ["kubectl", "logs", pod_name, "-n", PROJECT_NAME, "--tail=50"]) + + +@pytest.fixture(scope="module", autouse=True) +def setup_and_teardown(): + print(f"[DEBUG] Setting up batch e2e test with project {PROJECT_NAME}") + _setup_minikube_docker_env() + assert login() == 0, "Login failed" + + yield + + cleanup_project(PROJECT_NAME) + + +# ── Track readiness (set via module attribute for robustness) ──── +_mlflow_ready = False + + +def _skip_if_mlflow_not_ready(): + if not getattr(_this_module, "_mlflow_ready", False): + pytest.skip("MLflow registry not ready") + + +# ── Tests (ordered) ────────────────────────────────────────────── + + +@pytest.mark.order(1) +def test_create_project_with_batch(): + result = run_cli("projects", "add", "--name", PROJECT_NAME, "--batch-enabled") + assert result.returncode == 0, f"Project creation failed: {result.stderr}" + assert "Project created successfully" in result.stdout + + +@pytest.mark.order(2) +def test_mlflow_registry_responds(): + registry_url = f"http://{MP_HOSTNAME}/registry/{PROJECT_NAME}/" + deadline = time.time() + 300 + while time.time() < deadline: + result = subprocess.run( + ["curl", "-s", "-o", "/dev/null", "-w", "%{http_code}", registry_url], + capture_output=True, + text=True, + ) + if result.stdout == "200": + _this_module._mlflow_ready = True + return + print(f"[DEBUG] MLflow registry HTTP {result.stdout}, retrying in 5s...") + time.sleep(5) + assert False, "MLflow registry did not respond with 200 within 5 minutes" + + +@pytest.mark.order(3) +def test_train_and_push_model(): + _skip_if_mlflow_not_ready() 
+ api_url = f"http://{MP_HOSTNAME}/registry/{PROJECT_NAME}/api/2.0/mlflow/experiments/search" + response = requests.post(api_url, json={"max_results": 1}, timeout=10) + assert response.status_code == 200, f"MLflow API not ready: {response.status_code}" + + data = load_iris() + x_train, _, y_train, _ = train_test_split(data.data, data.target, test_size=0.3, random_state=42) + + mlflow.set_tracking_uri(f"http://{MP_HOSTNAME}/registry/{PROJECT_NAME}/") + + with mlflow.start_run(): + model = RandomForestClassifier(n_estimators=2, random_state=42) + model.fit(x_train, y_train) + mlflow.sklearn.log_model(model, "custom_model", registered_model_name=MODEL_NAME) + + deadline = time.time() + 60 + while time.time() < deadline: + result = run_cli("projects", "list-models", PROJECT_NAME) + if result.returncode == 0 and MODEL_NAME in result.stdout: + return + time.sleep(5) + assert False, f"Model {MODEL_NAME} not found after 60s" + + +@pytest.mark.order(4) +def test_submit_batch_prediction(): + """Submit a batch prediction via the API. 
This triggers auto-build of the Docker image.""" + _skip_if_mlflow_not_ready() + + # Create a small CSV matching the iris model (4 features) + csv_content = "0,1,2,3\n5.1,3.5,1.4,0.2\n6.2,2.9,4.3,1.3\n7.1,3.0,5.9,2.1\n" + csv_path = f"/tmp/batch_e2e_test_{PROJECT_NAME}.csv" + with open(csv_path, "w") as f: + f.write(csv_content) + + result = run_cli("batch", "submit", PROJECT_NAME, MODEL_NAME, MODEL_VERSION, "--file-path", csv_path) + assert result.returncode == 0, f"Batch submit failed: {result.stderr}\n{result.stdout}" + assert "Job ID:" in result.stdout or "job_id" in result.stdout + + # Extract job_id from output + for line in result.stdout.splitlines(): + if "Job ID:" in line: + _this_module._batch_job_id = line.split("Job ID:")[-1].strip() + print(f"[DEBUG] Batch job submitted with ID: {_this_module._batch_job_id}") + return + + assert False, f"Could not extract job ID from output: {result.stdout}" + + +_batch_job_id = None + + +def _skip_if_no_job(): + if not getattr(_this_module, "_batch_job_id", None): + pytest.skip("No batch job ID available") + + +@pytest.mark.order(5) +def test_batch_job_completes(): + """Poll batch job status until it completes (building → pending → running → completed).""" + _skip_if_mlflow_not_ready() + _skip_if_no_job() + + deadline = time.time() + 600 # 10 minutes — includes Docker image build + last_status = None + while time.time() < deadline: + result = run_cli("batch", "status", PROJECT_NAME, _this_module._batch_job_id) + output = result.stdout.lower() + + if "completed" in output: + print(f"[DEBUG] Batch job {_this_module._batch_job_id} completed") + return + + if "failed" in output: + print(f"[DEBUG] Batch job {_this_module._batch_job_id} FAILED") + _dump_batch_debug_info() + assert False, f"Batch job failed: {result.stdout}" + + # Extract status for debug + for status_word in ["building", "pending", "running"]: + if status_word in output: + if status_word != last_status: + print(f"[DEBUG] Batch job status: {status_word}") + 
last_status = status_word + break + + time.sleep(10) + + _dump_batch_debug_info() + assert False, f"Batch job did not complete within 10 minutes. Last status: {last_status}" + + +@pytest.mark.order(6) +def test_download_batch_result(): + """Download the batch prediction result and validate it.""" + _skip_if_mlflow_not_ready() + _skip_if_no_job() + + output_path = f"/tmp/batch_e2e_result_{PROJECT_NAME}.csv" + result = run_cli("batch", "download", PROJECT_NAME, _this_module._batch_job_id, "--output", output_path) + assert result.returncode == 0, f"Download failed: {result.stderr}\n{result.stdout}" + assert "downloaded" in result.stdout.lower() + + assert os.path.exists(output_path), f"Output file not found at {output_path}" + with open(output_path) as f: + lines = f.readlines() + # Header + 3 prediction rows + assert len(lines) >= 4, f"Expected at least 4 lines (header + 3 rows), got {len(lines)}: {lines}" + assert "prediction" in lines[0].lower(), f"Expected 'prediction' header, got: {lines[0]}" + print(f"[DEBUG] Downloaded {len(lines) - 1} predictions") + + +@pytest.mark.order(7) +def test_list_batch_jobs(): + """Verify the job appears in the list.""" + _skip_if_mlflow_not_ready() + _skip_if_no_job() + + result = run_cli("batch", "list", PROJECT_NAME) + assert result.returncode == 0, f"List failed: {result.stderr}" + assert ( + _this_module._batch_job_id in result.stdout + ), f"Job {_this_module._batch_job_id} not in list output: {result.stdout}" + + +@pytest.mark.order(8) +def test_delete_batch_job(): + """Delete the batch job and associated files.""" + _skip_if_mlflow_not_ready() + _skip_if_no_job() + + result = run_cli("batch", "delete", PROJECT_NAME, _this_module._batch_job_id) + assert result.returncode == 0, f"Delete failed: {result.stderr}\n{result.stdout}" + assert "deleted" in result.stdout.lower() + + +@pytest.mark.order(9) +def test_delete_project(): + result = run_cli("projects", "delete", PROJECT_NAME) + assert result.returncode == 0, f"Delete failed: 
{result.stderr}" + assert "Project deleted successfully" in result.stdout diff --git a/tests/tests_end_to_end/test_from_project_creation_to_model_predict.py b/tests/tests_end_to_end/test_from_project_creation_to_model_predict.py index b3af7a5..e15857e 100644 --- a/tests/tests_end_to_end/test_from_project_creation_to_model_predict.py +++ b/tests/tests_end_to_end/test_from_project_creation_to_model_predict.py @@ -13,6 +13,7 @@ import random import string import subprocess +import sys import time import os @@ -27,6 +28,8 @@ from backend.utils import sanitize_ressource_name from tests.conftest import cleanup_project, login, MP_HOSTNAME, run_cli +_this_module = sys.modules[__name__] + # Use random suffix to avoid conflicts with previous test runs (db dropper jobs, etc.) PROJECT_SUFFIX = "".join(random.choices(string.ascii_lowercase, k=6)) PROJECT_NAME = f"e2e{PROJECT_SUFFIX}" @@ -59,6 +62,7 @@ def setup_and_teardown(): cleanup_project(PROJECT_NAME) +@pytest.mark.order(1) def test_health_endpoint_responds(): """Test that the platform health endpoint responds.""" result = subprocess.run( @@ -69,13 +73,15 @@ def test_health_endpoint_responds(): assert result.stdout == "200", f"Health endpoint did not respond with 200: {result.stdout}" +@pytest.mark.order(2) def test_project_creation(): """Test project creation.""" result = run_cli("projects", "add", "--name", PROJECT_NAME) assert result.returncode == 0, f"Project creation failed: {result.stderr}" - assert "✅ Project created successfully" in result.stdout + assert "Project created successfully" in result.stdout +@pytest.mark.order(3) def test_mlflow_registry_responds(): """Test that MLflow registry responds after project creation.""" registry_url = f"http://{MP_HOSTNAME}/registry/{PROJECT_NAME}/" @@ -87,8 +93,7 @@ def test_mlflow_registry_responds(): text=True, ) if result.stdout == "200": - global _mlflow_ready - _mlflow_ready = True + _this_module._mlflow_ready = True return print(f"[DEBUG] MLflow registry HTTP 
{result.stdout}, retrying in 5s...") time.sleep(5) @@ -96,12 +101,12 @@ def test_mlflow_registry_responds(): assert False, f"MLflow registry did not respond with 200 within 5 minutes" -# Flag to track if MLflow is ready +# Flag to track if MLflow is ready (set via module attribute for robustness) _mlflow_ready = False def _skip_if_mlflow_not_ready(): - if not _mlflow_ready: + if not getattr(_this_module, "_mlflow_ready", False): pytest.skip("MLflow registry not ready - skipping dependent test") @@ -223,6 +228,7 @@ def _dump_registry_status(): _run_debug_cmd("curl registry index", ["curl", "-v", "--max-time", "20", registry_url]) +@pytest.mark.order(4) def test_train_and_push_model_to_mlflow(): """Test model training and push to MLflow.""" _skip_if_mlflow_not_ready() @@ -251,6 +257,7 @@ def test_train_and_push_model_to_mlflow(): assert False, f"Model {MODEL_NAME} not found after 60s. Last output: {result.stdout}" +@pytest.mark.order(5) def test_deploy_model(): """Test model deployment.""" _skip_if_mlflow_not_ready() @@ -284,6 +291,7 @@ def test_deploy_model(): assert False, f"Deployment {deployment_name} was not created within 5 minutes" +@pytest.mark.order(6) def test_deployed_model_health_check(): """Test that deployed model responds to health check.""" _skip_if_mlflow_not_ready() @@ -334,6 +342,7 @@ def test_deployed_model_health_check(): raise AssertionError("Model health endpoint did not respond with 200 within 5 minutes") +@pytest.mark.order(7) def test_deployed_model_returns_predictions(): """Test that deployed model returns predictions.""" _skip_if_mlflow_not_ready() @@ -375,6 +384,7 @@ def test_deployed_model_returns_predictions(): raise AssertionError("Could not get prediction from model within 2 minutes") +@pytest.mark.order(8) def test_list_deployed_models(): """Test that list deployed models shows the deployed model.""" _skip_if_mlflow_not_ready() @@ -383,6 +393,7 @@ def test_list_deployed_models(): assert MODEL_NAME in result.stdout or 
MODEL_NAME.replace("_", "-") in result.stdout +@pytest.mark.order(9) def test_undeploy_model(): """Test model undeployment.""" _skip_if_mlflow_not_ready() @@ -390,6 +401,7 @@ def test_undeploy_model(): assert result.returncode == 0, f"Undeploy failed: {result.stderr}" +@pytest.mark.order(10) def test_undeployed_model_is_removed(): """Test that undeployed model no longer has a K8s deployment.""" _skip_if_mlflow_not_ready() @@ -409,13 +421,15 @@ def test_undeployed_model_is_removed(): assert False, f"Deployment {deployment_name} should have been deleted after 30s" +@pytest.mark.order(11) def test_delete_project(): """Test project deletion.""" result = run_cli("projects", "delete", PROJECT_NAME) assert result.returncode == 0, f"Delete failed: {result.stderr}" - assert "✅ Project deleted successfully" in result.stdout + assert "Project deleted successfully" in result.stdout +@pytest.mark.order(12) def test_project_registry_is_removed(): """Test that project registry is no longer accessible after deletion.""" # Poll until registry is gone (instead of flat 60s sleep) diff --git a/uv.lock b/uv.lock index ed2ae0f..dd155f7 100644 --- a/uv.lock +++ b/uv.lock @@ -2212,6 +2212,7 @@ dev = [ { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-mock" }, + { name = "pytest-order" }, { name = "pytest-playwright" }, { name = "ruff" }, ] @@ -2263,6 +2264,7 @@ dev = [ { name = "pytest-asyncio", specifier = ">=0.25.3" }, { name = "pytest-cov", specifier = ">=4.1.0" }, { name = "pytest-mock", specifier = ">=3.14.0" }, + { name = "pytest-order", specifier = ">=1.2.0" }, { name = "pytest-playwright" }, { name = "ruff", specifier = ">=0.1.4" }, ] @@ -3458,6 +3460,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, ] 
+[[package]] +name = "pytest-order" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/66/02ae17461b14a52ce5a29ae2900156b9110d1de34721ccc16ccd79419876/pytest_order-1.3.0.tar.gz", hash = "sha256:51608fec3d3ee9c0adaea94daa124a5c4c1d2bb99b00269f098f414307f23dde", size = 47544, upload-time = "2024-08-22T12:29:54.512Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/73/59b038d1aafca89f8e9936eaa8ffa6bb6138d00459d13a32ce070be4f280/pytest_order-1.3.0-py3-none-any.whl", hash = "sha256:2cd562a21380345dd8d5774aa5fd38b7849b6ee7397ca5f6999bbe6e89f07f6e", size = 14609, upload-time = "2024-08-22T12:29:53.156Z" }, +] + [[package]] name = "pytest-playwright" version = "0.7.2"