From 0842d8de76533d6d166772a4330961476d2da0a8 Mon Sep 17 00:00:00 2001 From: ThomSerg Date: Fri, 30 Jan 2026 11:36:11 +0100 Subject: [PATCH 1/2] Cherry pick dataset files --- cpmpy/tools/dataset/__init__.py | 7 + cpmpy/tools/dataset/_base.py | 222 +++++++++++ cpmpy/tools/dataset/jsplib.py | 200 ++++++++++ cpmpy/tools/dataset/miplib.py | 101 +++++ cpmpy/tools/dataset/mse.py | 116 ++++++ cpmpy/tools/dataset/nurserostering.py | 523 ++++++++++++++++++++++++++ cpmpy/tools/dataset/opb.py | 166 ++++++++ cpmpy/tools/dataset/psplib.py | 100 +++++ cpmpy/tools/dataset/xcsp3.py | 119 ++++++ 9 files changed, 1554 insertions(+) create mode 100644 cpmpy/tools/dataset/__init__.py create mode 100644 cpmpy/tools/dataset/_base.py create mode 100644 cpmpy/tools/dataset/jsplib.py create mode 100644 cpmpy/tools/dataset/miplib.py create mode 100644 cpmpy/tools/dataset/mse.py create mode 100644 cpmpy/tools/dataset/nurserostering.py create mode 100644 cpmpy/tools/dataset/opb.py create mode 100644 cpmpy/tools/dataset/psplib.py create mode 100644 cpmpy/tools/dataset/xcsp3.py diff --git a/cpmpy/tools/dataset/__init__.py b/cpmpy/tools/dataset/__init__.py new file mode 100644 index 000000000..65fb041b8 --- /dev/null +++ b/cpmpy/tools/dataset/__init__.py @@ -0,0 +1,7 @@ +from .miplib import MIPLibDataset +from .jsplib import JSPLibDataset +from .psplib import PSPLibDataset +from .nurserostering import NurseRosteringDataset +from .xcsp3 import XCSP3Dataset +from .opb import OPBDataset +from .mse import MSEDataset diff --git a/cpmpy/tools/dataset/_base.py b/cpmpy/tools/dataset/_base.py new file mode 100644 index 000000000..caf0ffdcf --- /dev/null +++ b/cpmpy/tools/dataset/_base.py @@ -0,0 +1,222 @@ +""" +Dataset Base Class + +This module defines the abstract `_Dataset` class, which serves as the foundation +for loading and managing benchmark instance collections in CPMpy-based experiments. +It standardizes how datasets are stored, accessed, and optionally transformed. 
"""
Dataset Base Class

This module defines the abstract `_Dataset` class, which serves as the foundation
for loading and managing benchmark instance collections in CPMpy-based experiments.
It standardizes how datasets are stored, accessed, and optionally transformed.
"""

from abc import ABC, abstractmethod
import os
import pathlib
import io
import sys
import tempfile
from typing import Any, Optional, Tuple
from urllib.error import URLError
from urllib.request import HTTPError, Request, urlopen

# Optional dependency: nicer progress bars when available.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


def format_bytes(bytes_num):
    """
    Format bytes into a human-readable string (e.g., KB, MB, GB).

    Arguments:
        bytes_num (int or float): Number of bytes (non-negative).

    Returns:
        str: The size formatted with one decimal digit and a unit suffix.
    """
    for unit in ['bytes', 'KB', 'MB', 'GB', 'TB']:
        if bytes_num < 1024.0:
            return f"{bytes_num:.1f} {unit}"
        bytes_num /= 1024.0
    # Fix: the original fell through (returning None) for sizes >= 1024 TB.
    return f"{bytes_num:.1f} PB"


class _Dataset(ABC):
    """
    Abstract base class for PyTorch-style datasets of benchmarking instances.

    The `_Dataset` class provides a standardized interface for downloading and
    accessing benchmark instances. This class should not be used on its own.
    """

    def __init__(
        self,
        dataset_dir: str = ".",
        transform=None, target_transform=None,
        download: bool = False,
        extension: str = ".txt",
        **kwargs
    ):
        """
        Arguments:
            dataset_dir (str): Directory containing (or receiving) the instance files.
            transform (callable, optional): Transform applied to the instance file path.
            target_transform (callable, optional): Transform applied to the metadata dict.
            download (bool): If True, download the dataset when `dataset_dir` does not exist.
            extension (str): File extension used to recognise instance files (default=".txt").

        Raises:
            ValueError: If the dataset is missing and `download=False`, or if no
                instance file matches `extension` after an (optional) download.
        """
        self.dataset_dir = pathlib.Path(dataset_dir)
        self.transform = transform
        self.target_transform = target_transform
        self.extension = extension

        if not self.dataset_dir.exists():
            if not download:
                raise ValueError("Dataset not found. Please set download=True to download the dataset.")
            else:
                self.download()
                files = sorted(self.dataset_dir.rglob(f"*{self.extension}"))
                print(f"Finished downloading {len(files)} instances")

        files = sorted(self.dataset_dir.rglob(f"*{self.extension}"))
        if len(files) == 0:
            raise ValueError(f"Cannot find any instances inside dataset {self.dataset_dir}. Is it a valid dataset? If so, please report on GitHub.")

    @abstractmethod
    def category(self) -> dict:
        """
        Labels to distinguish instances into categories matching to those of the dataset.
        E.g.
            - year
            - track
        """
        pass

    @abstractmethod
    def download(self, *args, **kwargs):
        """
        How the dataset should be downloaded.
        """
        pass

    def open(self, instance) -> io.TextIOBase:
        """
        How an instance file from the dataset should be opened.
        Especially useful when files come compressed and won't work with
        python standard library's 'open', e.g. '.xz', '.lzma'.
        """
        return open(instance, "r")

    def metadata(self, file) -> dict:
        """
        Basic metadata for an instance file: the subclass' category labels
        plus the dataset name, instance name and path.
        """
        metadata = self.category() | {
            'dataset': self.name,  # `name` is defined by concrete subclasses
            # Fix: use .name instead of .stem so a multi-suffix extension
            # (e.g. ".mps.gz") is stripped entirely; .stem only removed the
            # last suffix, leaving e.g. "inst.mps" after the replace() no-op.
            'name': pathlib.Path(file).name.replace(self.extension, ''),
            'path': file,
        }
        return metadata

    def __len__(self) -> int:
        """Return the total number of instances."""
        return len(list(self.dataset_dir.rglob(f"*{self.extension}")))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Return the `(filename, metadata)` pair for the instance at `index`.

        Files are sorted by path so indexing is deterministic across calls.
        `transform` may replace the filename by parsed instance data;
        `target_transform` may rewrite the metadata dict.
        """
        if index < 0 or index >= len(self):
            raise IndexError("Index out of range")

        # Sort for deterministic behavior
        files = sorted(self.dataset_dir.rglob(f"*{self.extension}"))
        filename = str(files[index])

        # Basic metadata about the instance
        metadata = self.metadata(file=filename)
        if self.target_transform:
            metadata = self.target_transform(metadata)

        if self.transform:
            # does not need to remain a filename...
            filename = self.transform(filename)

        return filename, metadata

    @staticmethod
    def _download_file(url: str, target: str, destination: Optional[str] = None,
                       desc: str = None,
                       chunk_size: int = 1024 * 1024) -> os.PathLike:
        """
        Download a file from a URL with progress bar and speed information.

        Arguments:
            url (str): The base URL to download from.
            target (str): The target filename, appended to `url`.
            destination (str, optional): The destination path to save the file;
                a temporary file is used when omitted.
            desc (str, optional): Description to show in the progress bar.
                If None, uses the filename.
            chunk_size (int): Size of each chunk for download in bytes (default=1MB).

        Returns:
            pathlib.Path: The destination path where the downloaded file is saved.

        Raises:
            ValueError: If the download fails (HTTP or URL error).
        """
        if desc is None:
            desc = target

        if destination is None:
            # Fix: close the handle right away so the path can be re-opened
            # for writing below (re-opening an open NamedTemporaryFile fails
            # on Windows).
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.close()
            destination = tmp.name
        else:
            os.makedirs(os.path.dirname(destination), exist_ok=True)

        try:
            # First request only reads the Content-Length header;
            # _download_sequential performs the actual transfer.
            with urlopen(Request(url + target)) as response:
                total_size = int(response.headers.get('Content-Length', 0))

            _Dataset._download_sequential(url + target, destination, total_size, desc, chunk_size)

            return pathlib.Path(destination)

        except (HTTPError, URLError) as e:
            raise ValueError(f"Failed to download file from {url + target}. Error: {str(e)}")

    @staticmethod
    def _download_sequential(url: str, filepath: pathlib.Path, total_size: int, desc: str,
                             chunk_size: int = 1024 * 1024):
        """Download a file in chunks, reporting progress via tqdm or stdout."""
        req = Request(url)
        with urlopen(req) as response:
            if tqdm is not None:
                # total_size == 0 means the server did not announce a size;
                # tqdm then shows an open-ended counter instead of a bar.
                pbar_kwargs = dict(unit='B', unit_scale=True, unit_divisor=1024,
                                   desc=desc, file=sys.stdout, miniters=1,
                                   dynamic_ncols=True, ascii=False)
                if total_size > 0:
                    pbar_kwargs['total'] = total_size
                with tqdm(**pbar_kwargs) as pbar, open(filepath, 'wb') as f:
                    while chunk := response.read(chunk_size):
                        f.write(chunk)
                        pbar.update(len(chunk))
            else:
                # Fallback to a simple stdout progress line if tqdm is not available
                downloaded = 0
                with open(filepath, 'wb') as f:
                    while chunk := response.read(chunk_size):
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 0:
                            percent = (downloaded / total_size) * 100
                            sys.stdout.write(f"\r\033[KDownloading {desc}: {format_bytes(downloaded)}/{format_bytes(total_size)} ({percent:.1f}%)")
                        else:
                            sys.stdout.write(f"\r\033[KDownloading {desc}: {format_bytes(downloaded)}...")
                        sys.stdout.flush()
                sys.stdout.write("\n")
                sys.stdout.flush()
"""
PyTorch-style Dataset for Jobshop instances from JSPLib

Simply create a dataset instance and start iterating over its contents:
The `metadata` contains useful information about the current problem instance.

https://github.com/tamy0612/JSPLIB
"""

import io
import os
import json
import pathlib
from typing import Tuple, Any
import zipfile
import numpy as np

import cpmpy as cp
from cpmpy.tools.dataset._base import _Dataset


class JSPLibDataset(_Dataset):  # torch.utils.data.Dataset compatible
    """
    JSP Dataset in a PyTorch compatible format.

    More information on JSPLib can be found here: https://github.com/tamy0612/JSPLIB
    """

    name = "jsplib"

    def __init__(self, root: str = ".", transform=None, target_transform=None, download: bool = False):
        """
        Initialize the JSPLib Dataset.

        Arguments:
            root (str): Root directory containing the jsp instances (if 'download', instances will be downloaded to this location)
            transform (callable, optional): Optional transform to be applied on the instance data
            target_transform (callable, optional): Optional transform to be applied on the file path
            download (bool): If True, downloads the dataset from the internet and puts it in `root` directory
        """
        self.root = pathlib.Path(root)
        self.metadata_file = "instances.json"

        dataset_dir = self.root / self.name

        super().__init__(
            dataset_dir=dataset_dir,
            transform=transform, target_transform=target_transform,
            download=download, extension=""  # instance files have no extension
        )

    def category(self) -> dict:
        return {}  # no categories

    def download(self):
        """Download the full JSPLIB repository and extract its instance files."""
        url = "https://github.com/tamy0612/JSPLIB/archive/refs/heads/"  # download full repo...
        target = "master.zip"
        target_download_path = self.root / target

        print(f"Downloading JSPLib instances from github.com/tamy0612/JSPLIB")

        try:
            target_download_path = self._download_file(url, target, destination=str(target_download_path))
        except ValueError as e:
            raise ValueError(f"No dataset available on {url}. Error: {str(e)}")

        # Extract instance files (and the metadata json) from the repo archive
        with zipfile.ZipFile(target_download_path, 'r') as zip_ref:
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            for file_info in zip_ref.infolist():
                if file_info.filename.startswith("JSPLIB-master/instances/") and file_info.file_size > 0:
                    filename = pathlib.Path(file_info.filename).name
                    # renamed handle: the original shadowed the `target` variable
                    with zip_ref.open(file_info) as source, open(self.dataset_dir / filename, 'wb') as out_file:
                        out_file.write(source.read())
            # extract metadata file
            with zip_ref.open("JSPLIB-master/instances.json") as source, open(self.dataset_dir / self.metadata_file, 'wb') as out_file:
                out_file.write(source.read())

        # Clean up the zip file
        target_download_path.unlink()

    def __getitem__(self, index: int | str) -> Tuple[Any, Any]:
        """
        Get a single JSPLib instance filename and metadata.

        Args:
            index (int or str): Index or name of the instance to retrieve

        Returns:
            Tuple[Any, Any]: A tuple containing:
                - The filename of the instance (or the transformed instance data)
                - Metadata dictionary (name, optimum, bounds, path, ...)
        """
        if isinstance(index, int) and (index < 0 or index >= len(self)):
            raise IndexError("Index out of range")

        # Get all instance files and sort for deterministic behavior  # TODO: use natsort instead?
        files = sorted(self.dataset_dir.rglob("*[!.json]"))  # exclude metadata file
        if isinstance(index, int):
            file_path = files[index]
        elif isinstance(index, str):
            for file_path in files:
                if file_path.stem == index:
                    break
            else:
                raise IndexError(f"Instance {index} not found in dataset")

        filename = str(file_path)
        if self.transform:
            # does not need to remain a filename...
            filename = self.transform(filename)

        with open(self.dataset_dir / self.metadata_file, "r") as f:
            for entry in json.load(f):
                if entry["name"] == file_path.stem:
                    metadata = entry
                    if "bounds" not in metadata:
                        # known optimum: upper and lower bound coincide
                        metadata["bounds"] = {"upper": metadata.get("optimum"), "lower": metadata.get("optimum")}
                    # Fix: `del metadata['path']` raised KeyError when the
                    # json entry had no 'path' field; pop() is tolerant.
                    metadata.pop('path', None)
                    metadata['path'] = str(file_path)
                    break
            else:
                metadata = dict()

        if self.target_transform:
            metadata = self.target_transform(metadata)

        return filename, metadata

    def open(self, instance: os.PathLike) -> io.TextIOBase:
        """Open a (plain-text) instance file for reading."""
        return open(instance, "r")


def parse_jsp(filename: str):
    """
    Parse a JSPLib instance file.

    Returns two (n_jobs x n_tasks) integer matrices:
        - task_to_machines: indicating on which machine to run each task
        - task_durations: indicating the duration of each task
    """
    with open(filename, "r") as f:
        line = f.readline()
        while line.startswith("#"):  # skip comment header
            line = f.readline()
        n_jobs, n_tasks = map(int, line.strip().split(" "))
        # Fix: np.fromstring's text mode is deprecated; build from the
        # whitespace-split token list instead (same values, same dtype).
        matrix = np.array(f.read().split(), dtype=int).reshape((n_jobs, n_tasks * 2))

    # columns alternate (machine, duration) per task
    task_to_machines = matrix[:, 0::2].copy()
    task_durations = matrix[:, 1::2].copy()

    return task_to_machines, task_durations


def jobshop_model(task_to_machines, task_durations):
    """
    Create a CPMpy model for the Jobshop problem.

    Arguments:
        task_to_machines: (n_jobs x n_tasks) matrix with the machine of each task
        task_durations: (n_jobs x n_tasks) matrix with the duration of each task

    Returns:
        tuple: (model, (start, makespan)) with the CPMpy model and its decision variables
    """
    task_to_machines = np.array(task_to_machines)
    dur = np.array(task_durations)

    # Fix: compare/sum on the converted array `dur`, not the raw argument,
    # so plain (nested) lists are accepted as well.
    assert task_to_machines.shape == dur.shape

    n_jobs, n_tasks = task_to_machines.shape

    horizon = dur.sum()  # extremely bad upperbound... TODO
    start = cp.intvar(0, horizon, name="start", shape=(n_jobs, n_tasks))
    end = cp.intvar(0, horizon, name="end", shape=(n_jobs, n_tasks))
    makespan = cp.intvar(0, horizon, name="makespan")

    model = cp.Model()
    model += start + dur == end
    model += end[:, :-1] <= start[:, 1:]  # precedences within a job

    # tasks on the same machine may not overlap in time
    for machine in set(task_to_machines.flat):
        model += cp.NoOverlap(start[task_to_machines == machine],
                              dur[task_to_machines == machine],
                              end[task_to_machines == machine])

    model += end <= makespan
    model.minimize(makespan)

    return model, (start, makespan)


if __name__ == "__main__":
    dataset = JSPLibDataset(root=".", download=True, transform=parse_jsp)
    print("Dataset size:", len(dataset))
    print("Instance 0:", dataset[0])  # fix: the original printed the label without the value
"""
MIPLib Dataset

https://miplib.zib.de/
"""
# Fix: the module docstring previously linked https://maxsat-evaluations.github.io/
# (copy-pasted from the MSE dataset); MIPLib lives at miplib.zib.de.


import os
import gzip
import zipfile
import pathlib
import io

from cpmpy.tools.dataset._base import _Dataset


class MIPLibDataset(_Dataset):  # torch.utils.data.Dataset compatible
    """
    MIPLib Dataset in a PyTorch compatible format.

    More information on MIPLib can be found here: https://miplib.zib.de/
    """

    name = "miplib"

    def __init__(
        self,
        root: str = ".",
        year: int = 2024, track: str = "exact-unweighted",
        transform=None, target_transform=None,
        download: bool = False
    ):
        """
        Constructor for a dataset object of the MIPLib competition.

        Arguments:
            root (str): Root directory where datasets are stored or will be downloaded to (default=".").
            year (int): Year of the dataset to use (default=2024).
            track (str): Track name specifying which subset of the dataset instances to load (default="exact-unweighted").
                NOTE(review): these track names look copied from the MSE dataset —
                confirm the intended MIPLib subset names.
            transform (callable, optional): Optional transform applied to the instance file path.
            target_transform (callable, optional): Optional transform applied to the metadata dictionary.
            download (bool): If True, downloads the dataset if it does not exist locally (default=False).

        Raises:
            ValueError: If the dataset directory does not exist and `download=False`,
                or if the requested year/track combination is not available.
        """
        self.root = pathlib.Path(root)
        self.year = year
        self.track = track

        dataset_dir = self.root / self.name / str(year) / track

        super().__init__(
            dataset_dir=dataset_dir,
            transform=transform, target_transform=target_transform,
            download=download, extension=".mps.gz"
        )

    def category(self) -> dict:
        return {
            "year": self.year,
            "track": self.track
        }

    def download(self):
        """Download the full MIPLIB collection archive and extract it into `dataset_dir`."""
        url = "https://miplib.zib.de/downloads/"
        target = "collection.zip"
        target_download_path = self.root / target

        print(f"Downloading MIPLib instances from miplib.zib.de")

        try:
            target_download_path = self._download_file(url, target, destination=str(target_download_path))
        except ValueError as e:
            raise ValueError(f"No dataset available on {url}. Error: {str(e)}")

        # Extract files (flattened: only the file name is kept)
        with zipfile.ZipFile(target_download_path, 'r') as zip_ref:
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            for file_info in zip_ref.infolist():
                filename = pathlib.Path(file_info.filename).name
                # renamed handle: the original shadowed the `target` variable
                with zip_ref.open(file_info) as source, open(self.dataset_dir / filename, 'wb') as out_file:
                    out_file.write(source.read())

        # Clean up the zip file
        target_download_path.unlink()

    def open(self, instance: os.PathLike) -> io.TextIOBase:
        """Open an instance for reading, transparently decompressing `.gz` files."""
        return gzip.open(instance, "rt") if str(instance).endswith(".gz") else open(instance)


if __name__ == "__main__":
    dataset = MIPLibDataset(download=True)
    print("Dataset size:", len(dataset))
    print("Instance 0:", dataset[0])
"""
MaxSAT Evaluation (MSE) Dataset

https://maxsat-evaluations.github.io/
"""


import os
import lzma
import zipfile
import pathlib
import io

from cpmpy.tools.dataset._base import _Dataset


class MSEDataset(_Dataset):  # torch.utils.data.Dataset compatible
    """
    MaxSAT Evaluation (MSE) benchmark dataset.

    Gives access to the `.wcnf.xz` benchmark instances of the MaxSAT
    Evaluation competitions, organised per `year` and `track` (e.g.
    `"exact-unweighted"`, `"exact-weighted"`). When the instances are not
    yet present locally, they can be downloaded and extracted automatically.

    More information on the competition can be found here: https://maxsat-evaluations.github.io/
    """

    name = "mse"

    def __init__(
        self,
        root: str = ".",
        year: int = 2024, track: str = "exact-unweighted",
        transform=None, target_transform=None,
        download: bool = False
    ):
        """
        Constructor for a dataset object of the MSE competition.

        Arguments:
            root (str): Root directory where datasets are stored or will be downloaded to (default=".").
            year (int): Competition year of the dataset to use (default=2024).
            track (str): Track name selecting a subset of the competition instances (default="exact-unweighted").
            transform (callable, optional): Transform applied to the instance file path.
            target_transform (callable, optional): Transform applied to the metadata dictionary.
            download (bool): Download the dataset when it does not exist locally (default=False).

        Raises:
            ValueError: When the dataset directory is absent and `download=False`,
                or when the requested year/track combination is not available.
        """
        self.root = pathlib.Path(root)
        self.year = year
        self.track = track

        # Validate the requested dataset before touching the filesystem
        if not str(year).startswith('20'):
            raise ValueError("Year must start with '20'")
        if not track:
            raise ValueError("Track must be specified, e.g. OPT-LIN, DEC-LIN, ...")

        super().__init__(
            dataset_dir=self.root / self.name / str(year) / track,
            transform=transform, target_transform=target_transform,
            download=download, extension=".wcnf.xz"
        )

    def category(self) -> dict:
        return dict(year=self.year, track=self.track)

    def download(self):
        """Fetch and unpack the instance archive for this year/track combination."""
        url = f"https://www.cs.helsinki.fi/group/coreo/MSE{self.year}-instances/"
        target = f"mse{str(self.year)[2:]}-{self.track}.zip"
        archive_path = self.root / target

        print(f"Downloading MaxSAT Eval {self.year} {self.track} instances from cs.helsinki.fi")

        try:
            archive_path = self._download_file(url, target, destination=str(archive_path))
        except ValueError as e:
            raise ValueError(f"No dataset available for year {self.year} and track {self.track}. Error: {str(e)}")

        with zipfile.ZipFile(archive_path, 'r') as archive:
            # parents=True creates the year/track hierarchy in one go
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            # Flatten the archive: drop any leading folders, keep only file names
            for member in archive.infolist():
                out_path = self.dataset_dir / pathlib.Path(member.filename).name
                with archive.open(member) as src, open(out_path, 'wb') as dst:
                    dst.write(src.read())

        # Remove the archive once its contents are extracted
        archive_path.unlink()

    def open(self, instance: os.PathLike) -> io.TextIOBase:
        """Open an instance file, transparently decompressing `.xz` archives."""
        if str(instance).endswith(".xz"):
            return lzma.open(instance, "rt")
        return open(instance)


if __name__ == "__main__":
    dataset = MSEDataset(year=2024, track="exact-weighted", download=True)
    print("Dataset size:", len(dataset))
    print("Instance 0:", dataset[0])
"""
PyTorch-style Dataset for Nurserostering instances from schedulingbenchmarks.org

Simply create a dataset instance and start iterating over its contents:
The `metadata` contains useful information about the current problem instance.

https://schedulingbenchmarks.org/nrp/
"""

import os
import pathlib
import zipfile
import re
import io

import cpmpy as cp
from cpmpy.tools.dataset._base import _Dataset

# Optional dependencies
try:
    import pandas as pd
    _HAS_PANDAS = True
except ImportError:
    _HAS_PANDAS = False

try:
    from faker import Faker
    _HAS_FAKER = True
except ImportError:
    _HAS_FAKER = False


class NurseRosteringDataset(_Dataset):  # torch.utils.data.Dataset compatible
    """
    Nurserostering Dataset in a PyTorch compatible format.

    More information on nurserostering instances can be found here: https://schedulingbenchmarks.org/nrp/
    """

    name = "nurserostering"

    def __init__(self, root: str = ".", transform=None, target_transform=None, download: bool = False, sort_key=None):
        """
        Initialize the Nurserostering Dataset.

        Arguments:
            root (str): Root directory containing the nurserostering instances (if 'download', instances will be downloaded to this location)
            transform (callable, optional): Optional transform to be applied on the instance data
            target_transform (callable, optional): Optional transform to be applied on the file path
            download (bool): If True, downloads the dataset from the internet and puts it in `root` directory
            sort_key (callable, optional): Optional function to sort instance files. If None, uses Python's built-in sorted().
                For natural/numeric sorting, pass natsorted from natsort library.
                Example: from natsort import natsorted; dataset = NurseRosteringDataset(..., sort_key=natsorted)
        """
        self.root = pathlib.Path(root)
        self.sort_key = sorted if sort_key is None else sort_key

        dataset_dir = self.root / self.name

        super().__init__(
            dataset_dir=dataset_dir,
            transform=transform, target_transform=target_transform,
            download=download, extension=".txt"
        )

    def category(self) -> dict:
        return {}  # no categories

    def download(self):
        """Download and extract the benchmark instances (instances1_24.zip)."""
        url = "https://schedulingbenchmarks.org/nrp/data/"
        target = "instances1_24.zip"
        target_download_path = self.root / target

        print(f"Downloading Nurserostering instances from schedulingbenchmarks.org")

        try:
            target_download_path = self._download_file(url, target, destination=str(target_download_path))
        except ValueError as e:
            raise ValueError(f"No dataset available on {url}. Error: {str(e)}")

        # make directory and extract files
        with zipfile.ZipFile(target_download_path, 'r') as zip_ref:
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            for file_info in zip_ref.infolist():
                filename = pathlib.Path(file_info.filename).name
                # renamed handle: the original shadowed the `target` variable
                with zip_ref.open(file_info) as source, open(self.dataset_dir / filename, 'wb') as out_file:
                    out_file.write(source.read())

        # Clean up the zip file
        target_download_path.unlink()

    def open(self, instance: os.PathLike) -> io.TextIOBase:
        """Open a (plain-text) instance file for reading."""
        return open(instance, "r")


def _tag_to_data(string, tag, skip_lines=0, datatype=None, names=None, dtype=None):
    """
    Extract data from a tagged section in the input string.

    Args:
        string: Input string containing tagged sections
        tag: Tag name to search for (e.g., "SECTION_SHIFTS")
        skip_lines: Number of lines to skip after the tag
        datatype: Type hint for return value. If None, returns list of dicts (CSV rows).
            If int, str, etc., returns that type parsed from the section.
        names: Optional list of column names to rename headers to. If provided, must match
            the number of columns or be shorter (extra columns will keep original names).
        dtype: Optional dict mapping column names to data types for conversion.
            Example: {'Length': int, 'ShiftID': str}

    Returns:
        If datatype is None: list of dicts (CSV rows as dictionaries)
        If datatype is int or float: parsed value from the first line
        If datatype is str: the whole data section as one string
        None when the tag is absent or the section is empty
    """
    # Match from the tag up to the next blank line (or end of string)
    regex = rf'{tag}[\s\S]*?($|(?=\n\s*\n))'
    match = re.search(regex, string)

    if not match:
        return None

    lines = list(match.group().split("\n")[skip_lines + 1:])
    if not lines:
        return None

    # If datatype is a simple type (int, str, etc.), parse accordingly
    if datatype is not None and datatype not in (list, dict):
        if datatype is int or datatype is float:
            # For numeric types, return first line
            first_line = lines[0].strip()
            return datatype(first_line) if first_line else None
        elif datatype is str:
            # For string type, return the whole data section
            return "\n".join(lines).strip()

    # Parse header; remove '#' and strip whitespace, but keep exact names
    headers = [h.replace("#", "").strip() for h in lines[0].split(",")]

    # Rename columns if names provided
    if names is not None:
        for i, new_name in enumerate(names):
            if i < len(headers):
                headers[i] = new_name

    # Parse data rows
    rows = []
    for line in lines[1:]:
        if not line.strip():
            continue
        values = line.split(",")
        # Pad values if needed so every header gets a value
        while len(values) < len(headers):
            values.append("")
        row = {}
        for i in range(len(headers)):
            value = values[i].strip() if i < len(values) else ""
            col_name = headers[i]

            # Apply type conversion if dtype specified; empty cells become None
            if dtype is not None and col_name in dtype:
                target_type = dtype[col_name]
                row[col_name] = target_type(value) if value else None
            else:
                row[col_name] = value
        rows.append(row)

    return rows


def parse_scheduling_period(filename: str):
    """
    Parse a nurserostering instance file.

    Args:
        filename: Path to the nurserostering instance file.

    Returns a dictionary with native Python data structures (lists of dicts).
    Use _to_dataframes() transform to convert to pandas DataFrames if needed.
    Use _add_fake_names() transform to add randomly generated names to staff.
    """
    with open(filename, "r") as f:
        string = f.read()

    # Parse scheduling horizon (assumes the section is present; raises otherwise)
    horizon = int(_tag_to_data(string, "SECTION_HORIZON", skip_lines=2, datatype=int))

    # Parse shifts - dict keyed by ShiftID
    shifts_rows = _tag_to_data(string, "SECTION_SHIFTS",
                               names=["ShiftID", "Length", "cannot follow"],
                               dtype={'ShiftID': str, 'Length': int, 'cannot follow': str})
    shifts = {}
    for row in shifts_rows:
        cannot_follow_str = row.get("cannot follow") or ""
        shifts[row["ShiftID"]] = {
            "Length": row["Length"],
            "cannot follow": [v.strip() for v in cannot_follow_str.split("|") if v.strip()]
        }

    # Parse staff - list of dicts
    staff = _tag_to_data(string, "SECTION_STAFF",
                         names=["ID", "MaxShifts", "MaxTotalMinutes", "MinTotalMinutes", "MaxConsecutiveShifts", "MinConsecutiveShifts", "MinConsecutiveDaysOff", "MaxWeekends"],
                         dtype={'MaxShifts': str, 'MaxTotalMinutes': int, 'MinTotalMinutes': int, 'MaxConsecutiveShifts': int, 'MinConsecutiveShifts': int, 'MinConsecutiveDaysOff': int, 'MaxWeekends': int})

    # Process MaxShifts column - split by | and create max_shifts_* columns
    for nurse in staff:
        # Fix: the dtype converter stores None for empty cells, so guard with
        # `or ""` before .strip() (the original crashed with AttributeError).
        max_shifts_str = (nurse.get("MaxShifts") or "").strip()
        if max_shifts_str:
            for part in max_shifts_str.split("|"):
                if "=" in part:
                    shift_id, max_val = part.split("=", 1)
                    shift_id = shift_id.strip()
                    max_val = max_val.strip()
                    if shift_id and max_val:
                        nurse[f"max_shifts_{shift_id}"] = int(max_val)

    # Parse days off - this section has variable columns (EmployeeID + N day indices)
    # Parse as raw string since column count varies per row
    days_off_raw = _tag_to_data(string, "SECTION_DAYS_OFF", datatype=str)
    days_off = []
    if days_off_raw:
        for line in days_off_raw.split("\n"):
            line = line.strip()
            if not line or line.startswith("#") or line.lower().startswith("employeeid"):
                continue
            # Parse CSV-style line (handles variable number of columns)
            parts = line.split(",")
            if len(parts) > 0:
                employee_id = parts[0].strip()
                # Remaining parts are day indices; ignore out-of-horizon values
                for day_str in parts[1:]:
                    day_str = day_str.strip()
                    if day_str and day_str.isdigit():
                        day_idx = int(day_str)
                        if 0 <= day_idx < horizon:
                            days_off.append({"EmployeeID": employee_id, "DayIndex": day_idx})

    # Parse shift requests and cover requirements
    shift_on = _tag_to_data(string, "SECTION_SHIFT_ON_REQUESTS",
                            names=["EmployeeID", "Day", "ShiftID", "Weight"],
                            dtype={'Weight': int, "Day": int, "ShiftID": str})
    shift_off = _tag_to_data(string, "SECTION_SHIFT_OFF_REQUESTS",
                             names=["EmployeeID", "Day", "ShiftID", "Weight"],
                             dtype={'Weight': int, "Day": int, "ShiftID": str})
    cover = _tag_to_data(string, "SECTION_COVER",
                         names=["Day", "ShiftID", "Requirement", "Weight for under", "Weight for over"],
                         dtype={'Day': int, 'ShiftID': str, 'Requirement': int, 'Weight for under': int, 'Weight for over': int})

    return dict(horizon=horizon, shifts=shifts, staff=staff, days_off=days_off,
                shift_on=shift_on, shift_off=shift_off, cover=cover)


def _add_fake_names(data, seed=0):
    """
    Transform function to add randomly generated names to staff using Faker.

    This function can be used as a transform argument to NurseRosteringDataset
    to add fake names to the parsed data.

    Example:
        dataset = NurseRosteringDataset(
            root=".",
            transform=lambda fname: _add_fake_names(parse_scheduling_period(fname))
        )

    Or combine with other transforms:
        dataset = NurseRosteringDataset(
            root=".",
            transform=lambda fname: _to_dataframes(
                _add_fake_names(parse_scheduling_period(fname))
            )
        )

    Args:
        data: Dictionary returned by parse_scheduling_period()
        seed: Random seed for reproducible name generation (default: 0)

    Returns:
        Dictionary with 'name' field added to each staff member

    Raises:
        ImportError: If Faker is not installed
    """
    if not _HAS_FAKER:
        raise ImportError("Faker is required for add_fake_names(). Install it with: pip install faker")

    fake = Faker()
    fake.seed_instance(seed)

    # Add a unique fake first name to every staff member
    for nurse in data["staff"]:
        nurse["name"] = fake.unique.first_name()

    return data


def _to_dataframes(data):
    """
    Transform function to convert native data structures to pandas DataFrames.

    This function can be used as a transform argument to NurseRosteringDataset
    to convert the parsed data into pandas DataFrames for easier manipulation.

    Example:
        dataset = NurseRosteringDataset(
            root=".",
            transform=lambda fname: _to_dataframes(parse_scheduling_period(fname))
        )

    Args:
        data: Dictionary returned by parse_scheduling_period()

    Returns:
        Dictionary with pandas DataFrames instead of native structures

    Raises:
        ImportError: If pandas is not installed
    """
    if not _HAS_PANDAS:
        raise ImportError("pandas is required for to_dataframes(). Install it with: pip install pandas")

    result = {"horizon": data["horizon"]}

    # Convert shifts dict to DataFrame (re-joining the cannot-follow list)
    shifts_rows = []
    for shift_id, shift_data in data["shifts"].items():
        shifts_rows.append({"ShiftID": shift_id, "Length": shift_data["Length"],
                            "cannot follow": "|".join(shift_data["cannot follow"])})
    result["shifts"] = pd.DataFrame(shifts_rows).set_index("ShiftID")

    # Convert staff list to DataFrame
    result["staff"] = pd.DataFrame(data["staff"]).set_index("ID")

    # Convert days_off list to DataFrame
    result["days_off"] = pd.DataFrame(data["days_off"])

    # Convert shift_on, shift_off, cover lists to DataFrames
    result["shift_on"] = pd.DataFrame(data["shift_on"])
    result["shift_off"] = pd.DataFrame(data["shift_off"])
    result["cover"] = pd.DataFrame(data["cover"])

    return result
def nurserostering_model(horizon, shifts, staff, days_off, shift_on, shift_off, cover):
    """
    Create a CPMpy model for nurse rostering.

    Args:
        horizon: Number of days in the scheduling period
        shifts: Dict mapping shift_id to dict with shift data
        staff: List of dicts, each representing a nurse with their constraints
        days_off: List of dicts with days off for each nurse
        shift_on: List of dicts with shift-on requests for each nurse
        shift_off: List of dicts with shift-off requests for each nurse
        cover: List of dicts with cover requirements for each day and shift

    Returns:
        (model, nurse_view): the CPMpy model (minimization) and the
        (n_nurses x horizon) decision matrix. Value 0 means a free day;
        value k > 0 means the k-th shift of `shifts` (insertion order).

    Raises:
        KeyError: if a days-off or shift-request entry references an
            EmployeeID that does not occur in `staff`.
    """
    n_nurses = len(staff)

    FREE = 0  # domain value 0 encodes "no shift assigned"
    shift_ids = list(shifts.keys())
    SHIFTS = ["F"] + shift_ids  # index in this list == domain value in nurse_view

    # O(1) lookup from employee ID to row index in nurse_view.
    # Unknown IDs fail loudly with KeyError instead of crashing later
    # with an obscure None-index error.
    staff_index = {nurse['ID']: idx for idx, nurse in enumerate(staff)}

    nurse_view = cp.intvar(0, len(shifts), shape=(n_nurses, horizon), name="nv")

    model = cp.Model()

    # Shifts which cannot follow the shift on the previous day.
    for shift_id, shift_data in shifts.items():
        for other_shift in shift_data['cannot follow']:
            model += (nurse_view[:, :-1] == SHIFTS.index(shift_id)).implies(
                nurse_view[:, 1:] != SHIFTS.index(other_shift))

    # Maximum number of shifts of each type that can be assigned to each employee.
    for i, nurse in enumerate(staff):
        for shift_id in shift_ids:
            max_shifts = nurse[f"max_shifts_{shift_id}"]
            model += cp.Count(nurse_view[i], SHIFTS.index(shift_id)) <= max_shifts

    # Minimum and maximum amount of total time in minutes that can be assigned to each employee.
    shift_length = cp.cpm_array([0] + [shifts[sid]['Length'] for sid in shift_ids])  # FREE = length 0
    for i, nurse in enumerate(staff):
        time_worked = cp.sum(shift_length[nurse_view[i, d]] for d in range(horizon))
        model += time_worked <= nurse.get('MaxTotalMinutes')
        model += time_worked >= nurse.get('MinTotalMinutes')

    # Maximum number of consecutive shifts that can be worked before having a day off:
    # every window of max_days+1 days must contain at least one free day.
    for i, nurse in enumerate(staff):
        max_days = nurse.get('MaxConsecutiveShifts')
        for d in range(horizon - max_days):
            window = nurse_view[i, d:d + max_days + 1]
            model += cp.Count(window, FREE) >= 1  # at least one free day in this window

    # Minimum number of consecutive shifts that must be worked before having a day off.
    # NOTE(review): a working period that starts on day 0 is not constrained here
    # (the loop starts at d=1) -- confirm this matches the benchmark semantics.
    for i, nurse in enumerate(staff):
        min_days = nurse.get('MinConsecutiveShifts')
        for d in range(1, horizon):
            is_start_of_working_period = (nurse_view[i, d - 1] == FREE) & (nurse_view[i, d] != FREE)
            model += is_start_of_working_period.implies(cp.all(nurse_view[i, d:d + min_days] != FREE))

    # Minimum number of consecutive days off (mirror of the previous constraint).
    for i, nurse in enumerate(staff):
        min_days = nurse.get('MinConsecutiveDaysOff')
        for d in range(1, horizon):
            is_start_of_free_period = (nurse_view[i, d - 1] != FREE) & (nurse_view[i, d] == FREE)
            model += is_start_of_free_period.implies(cp.all(nurse_view[i, d:d + min_days] == FREE))

    # Max number of working weekends for each nurse (day 0 is assumed to be a Monday,
    # so Saturday/Sunday are the day pairs (5,6), (12,13), ...).
    weekends = [(i - 1, i) for i in range(1, horizon) if (i + 1) % 7 == 0]
    for i, nurse in enumerate(staff):
        n_weekends = cp.sum((nurse_view[i, sat] != FREE) | (nurse_view[i, sun] != FREE) for sat, sun in weekends)
        model += n_weekends <= nurse.get('MaxWeekends')

    # Days off (hard constraint).
    for holiday in days_off:
        i = staff_index[holiday['EmployeeID']]
        model += nurse_view[i, holiday['DayIndex']] == FREE

    # Shift-on requests, encoded in a linear objective: pay the weight when violated.
    objective = 0
    for request in shift_on:
        i = staff_index[request['EmployeeID']]
        cpm_request = nurse_view[i, request['Day']] == SHIFTS.index(request['ShiftID'])
        objective += request['Weight'] * ~cpm_request

    # Shift-off requests, idem.
    for request in shift_off:
        i = staff_index[request['EmployeeID']]
        cpm_request = nurse_view[i, request['Day']] != SHIFTS.index(request['ShiftID'])
        objective += request['Weight'] * ~cpm_request

    # Cover constraints: soft, with slack variables penalizing under/over-staffing.
    for cover_request in cover:
        nb_nurses = cp.Count(nurse_view[:, cover_request['Day']], SHIFTS.index(cover_request['ShiftID']))
        slack_over, slack_under = cp.intvar(0, len(staff), shape=2)
        model += nb_nurses - slack_over + slack_under == cover_request["Requirement"]
        objective += cover_request["Weight for over"] * slack_over + cover_request["Weight for under"] * slack_under

    model.minimize(objective)

    return model, nurse_view
if __name__ == "__main__":
    # Demo: download the dataset, solve the first instance and pretty-print
    # the resulting schedule. (The previous version had a second, duplicate
    # `if __name__ == "__main__":` block at the end of the file; removed.)
    dataset = NurseRosteringDataset(root=".", download=True, transform=parse_scheduling_period)
    print("Dataset size:", len(dataset))

    data, metadata = dataset[0]
    print(data)

    model, nurse_view = nurserostering_model(**data)
    assert model.solve()

    print(f"Found optimal solution with penalty of {model.objective_value()}")
    assert model.objective_value() == 607  # optimal solution for the first instance

    # --- Pretty print solution without pandas ---

    horizon = data['horizon']
    shift_ids = list(data['shifts'].keys())
    names = ["-"] + shift_ids  # printable symbol per domain value (0 = free day)
    sol = nurse_view.value()

    # Create table: rows are nurses + cover rows, columns are days
    table = []
    row_labels = []

    # Add nurse rows
    for i, nurse in enumerate(data['staff']):
        nurse_name = nurse.get('name', nurse.get('ID', f'Nurse_{i}'))
        row_labels.append(nurse_name)
        table.append([names[sol[i][d]] for d in range(horizon)])

    # Add cover rows (initialize with empty strings)
    for shift_id in shift_ids:
        row_labels.append(f'Cover {shift_id}')
        table.append([''] * horizon)

    # Fill in cover information: "assigned/required" per (day, shift)
    for cover_request in data['cover']:
        shift = cover_request['ShiftID']
        day = cover_request['Day']
        requirement = cover_request['Requirement']
        # Count how many nurses are assigned to this shift on this day
        num_shifts = sum(1 for i in range(len(data['staff']))
                         if sol[i][day] == shift_ids.index(shift) + 1)  # +1 because 0 is FREE
        cover_row_idx = len(data['staff']) + shift_ids.index(shift)
        table[cover_row_idx][day] = f"{num_shifts}/{requirement}"

    # Print table (day 0 assumed to be a Monday, matching the model's weekends)
    days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
    day_labels = [days[d % 7] for d in range(horizon)]

    # Calculate column widths
    col_widths = [max(len(str(row[i])) for row in table + [day_labels]) for i in range(horizon)]
    row_label_width = max(len(label) for label in row_labels)

    # Print header
    print(f"\n{'Schedule:':<{row_label_width}}", end="")
    for d, day_label in enumerate(day_labels):
        print(f" {day_label:>{col_widths[d]}}", end="")
    print()

    # Print separator
    print("-" * (row_label_width + 1 + sum(w + 1 for w in col_widths)))

    # Print rows
    for label, row in zip(row_labels, table):
        print(f"{label:<{row_label_width}}", end="")
        for d, val in enumerate(row):
            print(f" {str(val):>{col_widths[d]}}", end="")
        print()
class OPBDataset(_Dataset):
    """
    Pseudo Boolean Competition (PB) benchmark dataset.

    Provides access to benchmark instances from the Pseudo Boolean
    competitions. Instances are grouped by `year` and `track` (e.g.,
    `"OPT-LIN"`, `"DEC-LIN"`) and stored as `.opb.xz` files.
    If the dataset is not available locally, it can be automatically
    downloaded and extracted.

    More information on the competition can be found here: https://www.cril.univ-artois.fr/PB25/
    """

    name = "opb"

    def __init__(
            self,
            root: str = ".",
            year: int = 2024, track: str = "OPT-LIN",
            competition: bool = True,
            transform=None, target_transform=None,
            download: bool = False
    ):
        """
        Constructor for a dataset object of the PB competition.

        Arguments:
            root (str): Root directory where datasets are stored or will be downloaded to (default=".").
            year (int): Competition year of the dataset to use (default=2024).
            track (str): Track name specifying which subset of the competition instances to load (default="OPT-LIN").
            competition (bool): If True, the dataset will be filtered on competition-used instances.
            transform (callable, optional): Optional transform applied to the instance file path.
            target_transform (callable, optional): Optional transform applied to the metadata dictionary.
            download (bool): If True, downloads the dataset if it does not exist locally (default=False).

        Raises:
            ValueError: If the dataset directory does not exist and `download=False`,
                or if the requested year/track combination is not available.
        """
        self.root = pathlib.Path(root)
        self.year = year
        self.track = track
        self.competition = competition

        # Validate the requested dataset before touching the filesystem
        if not str(year).startswith('20'):
            raise ValueError("Year must start with '20'")
        if not track:
            raise ValueError("Track must be specified, e.g. exact-weighted, exact-unweighted, ...")

        dataset_dir = self.root / self.name / str(year) / track / ('selected' if self.competition else 'normalized')

        super().__init__(
            dataset_dir=dataset_dir,
            transform=transform, target_transform=target_transform,
            download=download, extension=".opb.xz"
        )

    def category(self) -> dict:
        return {
            "year": self.year,
            "track": self.track
        }

    def metadata(self, file) -> dict:
        # Instance files are named "<author>_<rest>", so the author is the
        # filename part before the first underscore.
        return super().metadata(file) | {'author': str(file).split(os.sep)[-1].split("_")[0], }

    def download(self):
        # Derive the competition URL from the requested year instead of
        # hard-coding a single edition (e.g. 2024 -> .../PB24/benchs/).
        url = f"https://www.cril.univ-artois.fr/PB{str(self.year)[2:]}/benchs/"
        target = f"{'normalized' if not self.competition else 'selected'}-PB{str(self.year)[2:]}.tar"
        target_download_path = self.root / target

        print(f"Downloading OPB {self.year} {self.track} {'competition' if self.competition else 'non-competition'} instances from www.cril.univ-artois.fr")

        try:
            target_download_path = self._download_file(url, target, destination=str(target_download_path))
        except ValueError as e:
            # Chain the original cause for easier debugging
            raise ValueError(f"No dataset available for year {self.year}. Error: {str(e)}") from e

        # Extract only the requested track folder from the tar
        with tarfile.open(target_download_path, "r:*") as tar_ref:  # r:* handles .tar, .tar.gz, .tar.bz2, etc.
            members = tar_ref.getmembers()

            # The archive contains a single top-level folder; find its name
            main_folder = next((m.name.split("/")[0] for m in members if "/" in m.name), None)
            if main_folder is None:
                raise ValueError("Could not find main folder in tar file")

            # Collect the available track names. In the 'selected' layout the
            # track folder sits one level deeper (main/<category>/<track>/...).
            track_depth = 1 if not self.competition else 2
            tracks = set()
            for member in members:
                parts = member.name.split("/")
                if len(parts) > 2 and parts[0] == main_folder:
                    tracks.add(parts[track_depth])

            # Check if requested track exists
            if self.track not in tracks:
                raise ValueError(f"Track '{self.track}' not found in dataset. Available tracks: {sorted(tracks)}")

            # Create track folder in root directory
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            # Extract files for the specified track
            if not self.competition:
                prefix = f"{main_folder}/{self.track}/"
            else:
                prefix = f"{main_folder}/*/{self.track}/"
            for member in members:
                if member.isfile() and fnmatch.fnmatch(member.name, prefix + "*"):
                    # Keep the path relative to .../<track>/; subfolders are
                    # preserved because some instances have clashing file names.
                    track_marker = f"/{self.track}/"
                    marker_pos = member.name.find(track_marker)
                    relative_path = member.name[marker_pos + len(track_marker):]

                    target_path = self.dataset_dir / relative_path
                    os.makedirs(os.path.dirname(target_path), exist_ok=True)

                    with tar_ref.extractfile(member) as source, open(target_path, "wb") as target:
                        target.write(source.read())

        # Clean up the tar file
        target_download_path.unlink()

    def open(self, instance: os.PathLike) -> io.TextIOBase:
        # Instances are LZMA-compressed; open transparently as text
        return lzma.open(instance, 'rt') if str(instance).endswith(".xz") else open(instance)


if __name__ == "__main__":
    dataset = OPBDataset(year=2024, track="DEC-LIN", competition=True, download=True)
    print("Dataset size:", len(dataset))
    print("Instance 0:", dataset[0])
class PSPLibDataset(_Dataset):  # torch.utils.data.Dataset compatible
    """
    PSPlib Dataset in a PyTorch compatible format.

    More information on PSPlib can be found here: https://www.om-db.wi.tum.de/psplib/main.html
    """

    name = "psplib"

    def __init__(self, root: str = ".", variant: str = "rcpsp", family: str = "j30", transform=None, target_transform=None, download: bool = False):
        """
        Constructor for a dataset object for PSPlib.

        Arguments:
            root (str): Root directory containing the psplib instances (if 'download', instances will be downloaded to this location)
            variant (str): scheduling variant (only 'rcpsp' is supported for now)
            family (str): family name (e.g. j30, j60, etc...)
            transform (callable, optional): Optional transform to be applied on the instance data
            target_transform (callable, optional): Optional transform to be applied on the file path
            download (bool): If True, downloads the dataset from the internet and puts it in `root` directory

        Raises:
            ValueError: If the dataset directory does not exist and `download=False`,
                or if the requested variant/family combination is not available.
        """
        self.root = pathlib.Path(root)
        self.variant = variant
        self.family = family

        # Known instance families per variant, and PSPLib's file-extension
        # code per variant (rcpsp -> .sm, mrcpsp -> .mm)
        self.families = dict(
            rcpsp=["j30", "j60", "j90", "j120"]
        )
        self.family_codes = dict(rcpsp="sm", mrcpsp="mm")

        if variant != "rcpsp":
            raise ValueError("Only 'rcpsp' variant is supported for now")
        if family not in self.families[variant]:
            raise ValueError(f"Unknown problem family. Must be any of {','.join(self.families[variant])}")

        dataset_dir = self.root / self.name / self.variant / self.family

        super().__init__(
            dataset_dir=dataset_dir,
            transform=transform, target_transform=target_transform,
            download=download, extension=f".{self.family_codes[self.variant]}"
        )

    def category(self) -> dict:
        return {
            "variant": self.variant,
            "family": self.family
        }

    def download(self):
        url = "https://www.om-db.wi.tum.de/psplib/files/"
        target = f"{self.family}.{self.family_codes[self.variant]}.zip"
        target_download_path = self.root / target

        print(f"Downloading PSPLib {self.variant} {self.family} instances from www.om-db.wi.tum.de")

        try:
            target_download_path = self._download_file(url, target, destination=str(target_download_path))
        except ValueError as e:
            # Chain the original cause for easier debugging
            raise ValueError(f"No dataset available for variant {self.variant} and family {self.family}. Error: {str(e)}") from e

        # Extract the archive, flattened: all instances land directly in dataset_dir
        with zipfile.ZipFile(target_download_path, 'r') as zip_ref:
            # parents=True ensures recursive creation
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            for file_info in zip_ref.infolist():
                if file_info.is_dir():
                    continue  # skip directory entries, only extract real files
                filename = pathlib.Path(file_info.filename).name
                with zip_ref.open(file_info) as source, open(self.dataset_dir / filename, 'wb') as target:
                    target.write(source.read())

        # Clean up the zip file
        target_download_path.unlink()


if __name__ == "__main__":
    dataset = PSPLibDataset(variant="rcpsp", family="j30", download=True)
    print("Dataset size:", len(dataset))
    print("Instance 0:", dataset[0])
class XCSP3Dataset(_Dataset):  # torch.utils.data.Dataset compatible
    """
    XCSP3 Dataset in a PyTorch compatible format.

    Arguments:
        root (str): Root directory containing the XCSP3 instances (if 'download', instances will be downloaded to this location)
        year (int): Competition year (2022, 2023 or 2024)
        track (str, optional): Filter instances by track type (e.g., "COP", "CSP", "MiniCOP")
        transform (callable, optional): Optional transform to be applied on the instance data (the file path of each problem instance)
        target_transform (callable, optional): Optional transform to be applied on the metadata (the metadata dictionary of each problem instance)
        download (bool): If True, downloads the dataset from the internet and puts it in `root` directory
    """

    name = "xcsp3"

    def __init__(self, root: str = ".", year: int = 2024, track: str = "CSP", transform=None, target_transform=None, download: bool = False):
        """
        Initialize the XCSP3 Dataset.
        """
        self.root = pathlib.Path(root)
        self.year = year
        self.track = track

        # Validate the request before computing paths or downloading anything
        if not str(year).startswith('20'):
            raise ValueError("Year must start with '20'")
        if not track:
            raise ValueError("Track must be specified, e.g. COP, CSP, MiniCOP, ...")

        dataset_dir = self.root / self.name / str(year) / track

        super().__init__(
            dataset_dir=dataset_dir,
            transform=transform, target_transform=target_transform,
            download=download, extension=".xml.lzma"
        )

    def category(self) -> dict:
        return {
            "year": self.year,
            "track": self.track
        }

    def download(self):
        url = "https://www.cril.univ-artois.fr/~lecoutre/compets/"
        target = f"instancesXCSP{str(self.year)[2:]}.zip"
        target_download_path = self.root / target

        print(f"Downloading XCSP3 {self.year} instances from www.cril.univ-artois.fr")

        try:
            target_download_path = self._download_file(url, target, destination=str(target_download_path))
        except ValueError as e:
            # Chain the original cause for easier debugging
            raise ValueError(f"No dataset available for year {self.year}. Error: {str(e)}") from e

        # Extract only the requested track folder from the zip
        with zipfile.ZipFile(target_download_path, 'r') as zip_ref:
            infos = zip_ref.infolist()

            # The archive contains a single top-level folder (e.g. "024_V3")
            main_folder = next((info.filename.split('/')[0] for info in infos if '/' in info.filename), None)
            if main_folder is None:
                raise ValueError("Could not find main folder in zip file")

            # Collect the available track names
            tracks = set()
            for file_info in infos:
                parts = file_info.filename.split('/')
                if len(parts) > 2 and parts[0] == main_folder:
                    tracks.add(parts[1])

            # Check if requested track exists
            if self.track not in tracks:
                raise ValueError(f"Track '{self.track}' not found in dataset. Available tracks: {sorted(tracks)}")

            # Create track folder in root directory, parents=True ensures recursive creation
            self.dataset_dir.mkdir(parents=True, exist_ok=True)

            # Extract the requested track's files, flattened into dataset_dir
            # (skip directory entries: they also match the prefix)
            prefix = f"{main_folder}/{self.track}/"
            for file_info in infos:
                if file_info.filename.startswith(prefix) and not file_info.is_dir():
                    filename = pathlib.Path(file_info.filename).name
                    with zip_ref.open(file_info) as source, open(self.dataset_dir / filename, 'wb') as target:
                        target.write(source.read())

        # Clean up the zip file
        target_download_path.unlink()

    def open(self, instance: os.PathLike) -> io.TextIOBase:
        # Instances are LZMA-compressed XML; open transparently as text
        return lzma.open(instance, mode='rt', encoding='utf-8') if str(instance).endswith(".lzma") else open(instance)


if __name__ == "__main__":
    dataset = XCSP3Dataset(year=2024, track="MiniCOP", download=True)
    print("Dataset size:", len(dataset))
    print("Instance 0:", dataset[0])