diff --git a/test/unittests/test_managers/test_data_manager.py b/test/unittests/test_managers/test_data_manager.py index 6ebc353..5a2fcd1 100644 --- a/test/unittests/test_managers/test_data_manager.py +++ b/test/unittests/test_managers/test_data_manager.py @@ -1,7 +1,62 @@ +import threading + import pytest from unittest.mock import MagicMock, patch from weaviate_cli.managers.data_manager import DataManager import weaviate.classes.config as wvc +from weaviate.collections.classes.tenants import TenantActivityStatus + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _setup_mock_client_with_col(mock_client, col): + """Attach col to mock_client.collections with correct mock setup.""" + mock_collections = MagicMock() + mock_client.collections = mock_collections + mock_collections.exists.return_value = True + mock_collections.get.return_value = col + + +def _make_mt_col(tenant_names): + """Return a mock collection with MT enabled and the given tenant names.""" + col = MagicMock() + col.name = "TestCollection" + tenants_dict = {name: MagicMock() for name in tenant_names} + col.config.get.return_value = MagicMock( + multi_tenancy_config=MagicMock( + enabled=True, + auto_tenant_creation=False, + auto_tenant_activation=False, + ) + ) + col.tenants.get.return_value = tenants_dict + # with_tenant returns a fresh mock per tenant + col.with_tenant.side_effect = lambda name: MagicMock(name=f"col__{name}") + col.__len__ = MagicMock(return_value=0) + return col + + +def _make_non_mt_col(): + """Return a mock collection with MT disabled.""" + col = MagicMock() + col.name = "TestCollection" + col.config.get.return_value = MagicMock( + multi_tenancy_config=MagicMock( + enabled=False, + auto_tenant_creation=False, + auto_tenant_activation=False, + ) + ) + col.__len__ = MagicMock(return_value=0) + return col + + +# --------------------------------------------------------------------------- +# Existing smoke tests (kept for backwards compat) +# --------------------------------------------------------------------------- # --------------------------------------------------------------------------- @@ -348,7 +403,6 @@ def test_ingest_data(mock_client): multi_tenancy_config=MagicMock(auto_tenant_creation=True) ) - # Test data ingestion manager.create_data( collection="TestCollection", limit=100, @@ -368,7 +422,6 @@ def test_update_data(mock_client): mock_collection = MagicMock() mock_client.collections.get.return_value = mock_collection - # Test data update manager.update_data( collection="TestCollection", limit=100, @@ -387,10 +440,423 @@ def test_delete_data(mock_client): mock_collection = MagicMock() mock_client.collections.get.return_value = mock_collection - # Test data deletion manager.delete_data( collection="TestCollection", limit=100, ) mock_client.collections.get.assert_called_once_with("TestCollection") + + +# --------------------------------------------------------------------------- +# update_data – parallel tenant processing +# --------------------------------------------------------------------------- + + +class TestUpdateDataParallel: + def test_all_tenants_processed_in_parallel(self, mock_client): + """All tenants are processed when parallel_workers > 1.""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1", "Tenant-2"]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + lock = threading.Lock() + + def fake_update(collection, *args, **kwargs): + with lock: + processed.append(collection) + return 5 + + with patch.object( + manager, "_DataManager__update_data", side_effect=fake_update + ): + manager.update_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + assert len(processed) == 3 + + def test_sequential_when_parallel_workers_is_1(self, mock_client): + """When parallel_workers=1, all tenants are still processed (sequentially).""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1", "Tenant-2"]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + + def fake_update(collection, *args, **kwargs): + processed.append(collection) + return 3 + + with patch.object( + manager, "_DataManager__update_data", side_effect=fake_update + ): + manager.update_data( + collection="TestCollection", + limit=3, + parallel_workers=1, + ) + + assert len(processed) == 3 + + def test_parallel_errors_collected_and_raised(self, mock_client): + """Errors from parallel tenant updates are collected and raised together.""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1"]) + _setup_mock_client_with_col(mock_client, col) + + def fake_update(collection, *args, **kwargs): + raise Exception("simulated update error") + + with patch.object( + manager, "_DataManager__update_data", side_effect=fake_update + ): + with pytest.raises(Exception, match="Errors during parallel data update"): + manager.update_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + def test_non_mt_collection_unaffected(self, mock_client): + """Non-MT collections are processed the same regardless of parallel_workers.""" + manager = DataManager(mock_client) + col = _make_non_mt_col() + _setup_mock_client_with_col(mock_client, col) + # Simulate MT-disabled exception from tenants.get() + col.tenants.get.side_effect = Exception("multi-tenancy is not enabled") + + processed = [] + + def fake_update(collection, *args, **kwargs): + processed.append(collection) + return 10 + + with patch.object( + manager, "_DataManager__update_data", side_effect=fake_update + ): + manager.update_data( + collection="TestCollection", + limit=10, + parallel_workers=4, + ) + + # Single "None" pseudo-tenant processed + assert len(processed) == 1 + + +# --------------------------------------------------------------------------- +# delete_data – parallel tenant processing +# --------------------------------------------------------------------------- + + +class TestDeleteDataParallel: + def test_all_tenants_processed_in_parallel(self, mock_client): + """All tenants are processed when parallel_workers > 1.""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1", "Tenant-2"]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + lock = threading.Lock() + + def fake_delete(collection, *args, **kwargs): + with lock: + processed.append(collection) + return 5 + + with patch.object( + manager, "_DataManager__delete_data", side_effect=fake_delete + ): + manager.delete_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + assert len(processed) == 3 + + def test_sequential_when_parallel_workers_is_1(self, mock_client): + """When parallel_workers=1, all tenants are still processed (sequentially).""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1"]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + + def fake_delete(collection, *args, **kwargs): + processed.append(collection) + return 3 + + with patch.object( + manager, "_DataManager__delete_data", side_effect=fake_delete + ): + manager.delete_data( + collection="TestCollection", + limit=3, + parallel_workers=1, + ) + + assert len(processed) == 2 + + def test_parallel_errors_collected_and_raised(self, mock_client): + """Errors from parallel tenant deletions are collected and raised together.""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1"]) + _setup_mock_client_with_col(mock_client, col) + + def fake_delete(collection, *args, **kwargs): + raise Exception("simulated delete error") + + with patch.object( + manager, "_DataManager__delete_data", side_effect=fake_delete + ): + with pytest.raises(Exception, match="Errors during parallel data deletion"): + manager.delete_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + def test_specific_tenants_list_processed_in_parallel(self, mock_client): + """When tenants_list is provided, only those tenants are processed.""" + manager = DataManager(mock_client) + # Collection has 5 tenants but we only target 2 + col = _make_mt_col([f"Tenant-{i}" for i in range(5)]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + lock = threading.Lock() + + def fake_delete(collection, *args, **kwargs): + with lock: + processed.append(collection) + return 1 + + with patch.object( + manager, "_DataManager__delete_data", side_effect=fake_delete + ): + manager.delete_data( + collection="TestCollection", + limit=1, + tenants_list=["Tenant-0", "Tenant-2"], + parallel_workers=4, + ) + + assert len(processed) == 2 + + def test_non_mt_collection_unaffected(self, mock_client): + """Non-MT collections are processed without parallelism.""" + manager = DataManager(mock_client) + col = _make_non_mt_col() + _setup_mock_client_with_col(mock_client, col) + + processed = [] + + def fake_delete(collection, *args, **kwargs): + processed.append(collection) + return 5 + + with patch.object( + manager, "_DataManager__delete_data", side_effect=fake_delete + ): + manager.delete_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + assert len(processed) == 1 + + +# --------------------------------------------------------------------------- +# create_data – parallel tenant processing +# --------------------------------------------------------------------------- + + +class TestCreateDataParallel: + def _make_col(self, tenant_names): + """MT collection with active tenant status ready for create_data.""" + col = _make_mt_col(tenant_names) + tenant_status = MagicMock() + tenant_status.activity_status = TenantActivityStatus.ACTIVE + col.tenants.get_by_name.return_value = tenant_status + return col + + def test_all_tenants_processed_in_parallel(self, mock_client): + """All tenants are processed when parallel_workers > 1.""" + manager = DataManager(mock_client) + col = self._make_col(["Tenant-0", "Tenant-1", "Tenant-2"]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + lock = threading.Lock() + + def fake_ingest(collection, **kwargs): + with lock: + processed.append(collection) + return collection + + with patch.object( + manager, "_DataManager__ingest_data", side_effect=fake_ingest + ): + manager.create_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + assert len(processed) == 3 + + def test_sequential_when_parallel_workers_is_1(self, mock_client): + """When parallel_workers=1, all tenants are still processed (sequentially).""" + manager = DataManager(mock_client) + col = self._make_col(["Tenant-0", "Tenant-1", "Tenant-2"]) + _setup_mock_client_with_col(mock_client, col) + + processed = [] + + def fake_ingest(collection, **kwargs): + processed.append(collection) + return collection + + with patch.object( + manager, "_DataManager__ingest_data", side_effect=fake_ingest + ): + manager.create_data( + collection="TestCollection", + limit=5, + parallel_workers=1, + ) + + assert len(processed) == 3 + + def test_parallel_errors_collected_and_raised(self, mock_client): + """Errors from parallel tenant ingestion are collected and raised together.""" + manager = DataManager(mock_client) + col = self._make_col(["Tenant-0", "Tenant-1"]) + _setup_mock_client_with_col(mock_client, col) + + def fake_ingest(collection, **kwargs): + raise Exception("simulated ingest error") + + with patch.object( + manager, "_DataManager__ingest_data", side_effect=fake_ingest + ): + with pytest.raises( + Exception, match="Errors during parallel data ingestion" + ): + manager.create_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + def test_non_mt_collection_unaffected(self, mock_client): + """Non-MT collections are processed as a single 'None' tenant.""" + manager = DataManager(mock_client) + col = _make_non_mt_col() + _setup_mock_client_with_col(mock_client, col) + + processed = [] + + def fake_ingest(collection, **kwargs): + processed.append(collection) + return collection + + with patch.object( + manager, "_DataManager__ingest_data", side_effect=fake_ingest + ): + manager.create_data( + collection="TestCollection", + limit=5, + parallel_workers=4, + ) + + # Single "None" pseudo-tenant processed + assert len(processed) == 1 + + +# --------------------------------------------------------------------------- +# create_data – concurrent_requests scaling with parallel_workers +# --------------------------------------------------------------------------- + + +class TestCreateDataConcurrentRequestsScaling: + def test_concurrent_requests_reduced_for_parallel_tenants(self, mock_client): + """When parallel_workers > 1 with multiple tenants, concurrent_requests + per tenant is divided to keep total connections bounded.""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0", "Tenant-1"]) + col.tenants.exists.return_value = True + tenant_status = MagicMock() + tenant_status.activity_status = TenantActivityStatus.ACTIVE + col.tenants.get_by_name.return_value = tenant_status + tenant_col = MagicMock() + tenant_col.__len__ = MagicMock(return_value=0) + col.with_tenant.return_value = tenant_col + _setup_mock_client_with_col(mock_client, col) + + captured_concurrent = [] + lock = threading.Lock() + + def fake_ingest(collection, *, concurrent_requests, **kwargs): + with lock: + captured_concurrent.append(concurrent_requests) + return collection + + with patch.object( + manager, "_DataManager__ingest_data", side_effect=fake_ingest + ): + manager.create_data( + collection="TestCollection", + limit=10, + randomize=True, + tenant_suffix="Tenant", + concurrent_requests=8, + parallel_workers=4, + ) + + # actual_workers = min(parallel_workers, len(tenants), concurrent_requests) + # = min(4, 2, 8) = 2 + # effective_concurrent = concurrent_requests // actual_workers = 8 // 2 = 4 + expected = max(1, 8 // min(4, 2, 8)) + assert all(c == expected for c in captured_concurrent) + assert len(captured_concurrent) == 2 + + def test_concurrent_requests_unchanged_for_single_tenant(self, mock_client): + """With a single tenant, concurrent_requests is not reduced.""" + manager = DataManager(mock_client) + col = _make_mt_col(["Tenant-0"]) + col.tenants.exists.return_value = True + tenant_status = MagicMock() + tenant_status.activity_status = TenantActivityStatus.ACTIVE + col.tenants.get_by_name.return_value = tenant_status + tenant_col = MagicMock() + tenant_col.__len__ = MagicMock(return_value=0) + col.with_tenant.return_value = tenant_col + _setup_mock_client_with_col(mock_client, col) + + captured_concurrent = [] + + def fake_ingest(collection, *, concurrent_requests, **kwargs): + captured_concurrent.append(concurrent_requests) + return collection + + with patch.object( + manager, "_DataManager__ingest_data", side_effect=fake_ingest + ): + manager.create_data( + collection="TestCollection", + limit=10, + randomize=True, + tenant_suffix="Tenant", + concurrent_requests=8, + parallel_workers=4, + ) + + # Single tenant: no reduction + assert captured_concurrent == [8] diff --git a/test/unittests/test_managers/test_tenant_manager.py b/test/unittests/test_managers/test_tenant_manager.py index 34d9cd1..3f683ca 100644 --- a/test/unittests/test_managers/test_tenant_manager.py +++ b/test/unittests/test_managers/test_tenant_manager.py @@ -407,7 +407,7 @@ def test_happy_path_deletes_tenants_text_output( json_output=False, ) - mock_collection.tenants.remove.assert_called_once_with(mock_tenant) + mock_collection.tenants.remove.assert_called_once_with([mock_tenant]) out = capsys.readouterr().out assert "1" in out assert "deleted" in out @@ -459,7 +459,10 @@ def test_deletes_only_up_to_number_tenants( json_output=False, ) - assert mock_collection.tenants.remove.call_count == 2 + # Batch delete: single call with list of tenants to delete + mock_collection.tenants.remove.assert_called_once() + removed = mock_collection.tenants.remove.call_args[0][0] + assert len(removed) == 2 def test_deletes_using_tenants_list(self, mock_client: MagicMock, capsys) -> None: """When tenants_list is provided, get_by_names is used.""" @@ -483,7 +486,7 @@ def test_deletes_using_tenants_list(self, mock_client: MagicMock, capsys) -> Non # get() should NOT be called when tenants_list is provided mock_collection.tenants.get.assert_not_called() assert mock_collection.tenants.get_by_names.call_count == 2 - mock_collection.tenants.remove.assert_called_once_with(specific_tenant) + mock_collection.tenants.remove.assert_called_once_with([specific_tenant]) def test_delete_all_with_wildcard_suffix( self, mock_client: MagicMock, capsys @@ -505,7 +508,10 @@ def test_delete_all_with_wildcard_suffix( json_output=False, ) - assert mock_collection.tenants.remove.call_count == 2 + # Batch delete: single call with all tenants + mock_collection.tenants.remove.assert_called_once() + removed = mock_collection.tenants.remove.call_args[0][0] + assert len(removed) == 2 # --------------------------------------------------------------------------- @@ -844,5 +850,5 @@ def test_skips_update_call_when_tenant_already_in_desired_state( json_output=False, ) - # The existing tenant object should be passed directly (no new Tenant) - mock_collection.tenants.update.assert_called_once_with(already_active) + # Batch update: single call with list containing the existing tenant as-is + mock_collection.tenants.update.assert_called_once_with([already_active]) diff --git a/weaviate_cli/commands/create.py b/weaviate_cli/commands/create.py index 0276cf8..8952880 100644 --- a/weaviate_cli/commands/create.py +++ b/weaviate_cli/commands/create.py @@ -521,6 +521,12 @@ def create_backup_cli( type=int, help=f"Number of concurrent requests to send to the server (default: {MAX_WORKERS}).", ) +@click.option( + "--parallel_workers", + default=CreateDataDefaults.parallel_workers, + type=click.IntRange(min=1), + help=f"Number of tenants to process in parallel (default: {CreateDataDefaults.parallel_workers}). Set to 1 to disable parallelism.", +) @click.option( "--json", "json_output", is_flag=True, default=False, help="Output in JSON format." ) @@ -543,6 +549,7 @@ def create_data_cli( dynamic_batch, batch_size, concurrent_requests, + parallel_workers, json_output, ): """Ingest data into a collection in Weaviate.""" @@ -593,6 +600,7 @@ def create_data_cli( dynamic_batch=dynamic_batch, batch_size=batch_size, concurrent_requests=concurrent_requests, + parallel_workers=parallel_workers, json_output=json_output, ) except Exception as e: diff --git a/weaviate_cli/commands/delete.py b/weaviate_cli/commands/delete.py index d70add2..04427e2 100644 --- a/weaviate_cli/commands/delete.py +++ b/weaviate_cli/commands/delete.py @@ -159,12 +159,26 @@ def delete_tenants_cli( default=DeleteDataDefaults.verbose, help="Show detailed progress information (default: False).", ) +@click.option( + "--parallel_workers", + default=DeleteDataDefaults.parallel_workers, + type=click.IntRange(min=1), + help=f"Number of tenants to process in parallel (default: {DeleteDataDefaults.parallel_workers}). Set to 1 to disable parallelism.", +) @click.option( "--json", "json_output", is_flag=True, default=False, help="Output in JSON format." ) @click.pass_context def delete_data_cli( - ctx, collection, limit, consistency_level, tenants, uuid, verbose, json_output + ctx, + collection, + limit, + consistency_level, + tenants, + uuid, + verbose, + parallel_workers, + json_output, ): """Delete data from a collection in Weaviate.""" @@ -184,6 +198,7 @@ def delete_data_cli( ), uuid=uuid, verbose=verbose, + parallel_workers=parallel_workers, json_output=json_output, ) except Exception as e: diff --git a/weaviate_cli/commands/update.py b/weaviate_cli/commands/update.py index 71dbe1d..15c41b0 100644 --- a/weaviate_cli/commands/update.py +++ b/weaviate_cli/commands/update.py @@ -313,6 +313,12 @@ def update_shards_cli( default=UpdateDataDefaults.verbose, help="Show detailed progress information (default: True).", ) +@click.option( + "--parallel_workers", + default=UpdateDataDefaults.parallel_workers, + type=click.IntRange(min=1), + help=f"Number of tenants to process in parallel (default: {UpdateDataDefaults.parallel_workers}). Set to 1 to disable parallelism.", +) @click.option( "--json", "json_output", is_flag=True, default=False, help="Output in JSON format." ) @@ -325,6 +331,7 @@ def update_data_cli( randomize, skip_seed, verbose, + parallel_workers, json_output, ): """Update data in a collection in Weaviate.""" @@ -341,6 +348,7 @@ def update_data_cli( randomize=randomize, skip_seed=skip_seed, verbose=verbose, + parallel_workers=parallel_workers, json_output=json_output, ) except Exception as e: diff --git a/weaviate_cli/defaults.py b/weaviate_cli/defaults.py index 9d0fb4e..fcc9bc1 100644 --- a/weaviate_cli/defaults.py +++ b/weaviate_cli/defaults.py @@ -123,6 +123,7 @@ class CreateDataDefaults: multi_vector: bool = False batch_size: int = 1000 dynamic_batch: bool = False + parallel_workers: int = MAX_WORKERS @dataclass @@ -179,6 +180,7 @@ class DeleteDataDefaults: consistency_level: str = "quorum" uuid: Optional[str] = None verbose: bool = False + parallel_workers: int = MAX_WORKERS @dataclass @@ -287,6 +289,7 @@ class UpdateDataDefaults: randomize: bool = False skip_seed: bool = False verbose: bool = False + parallel_workers: int = MAX_WORKERS @dataclass diff --git a/weaviate_cli/managers/data_manager.py b/weaviate_cli/managers/data_manager.py index ff66e18..72e7962 100644 --- a/weaviate_cli/managers/data_manager.py +++ b/weaviate_cli/managers/data_manager.py @@ -2,8 +2,10 @@ import json import math import random +import threading import time from collections import deque +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta from typing import Dict, List, Optional, Union, Any, Tuple @@ -819,6 +821,7 @@ def create_data( dynamic_batch: bool = CreateDataDefaults.dynamic_batch, batch_size: int = CreateDataDefaults.batch_size, concurrent_requests: int = MAX_WORKERS, + parallel_workers: int = CreateDataDefaults.parallel_workers, json_output: bool = False, ) -> Collection: @@ -860,10 +863,26 @@ def create_data( if not json_output: click.echo(f"Preparing to insert {limit} objects into class '{col.name}'") total_inserted = 0 - for tenant in tenants: + + # Clamp actual thread count to the number of tenants (no point creating + # more threads than tasks) and to concurrent_requests (so the max(1,…) + # floor can't push total in-flight above the budget when + # parallel_workers > concurrent_requests). + actual_workers = min(parallel_workers, len(tenants), concurrent_requests) + effective_concurrent = ( + max(1, concurrent_requests // actual_workers) + if actual_workers > 1 + else concurrent_requests + ) + + _parallel_mode = actual_workers > 1 + _output_lock = threading.Lock() + + def _ingest_one_tenant(tenant: str): + """Ingest data for a single tenant; returns (inserted_count, collection).""" if tenant == "None": - initial_length = len(col) - collection = self.__ingest_data( + _initial = len(col) + _coll = self.__ingest_data( collection=col, num_objects=limit, cl=cl_map[consistency_level], @@ -875,10 +894,10 @@ def create_data( multi_vector=multi_vector, dynamic_batch=dynamic_batch, batch_size=batch_size, - concurrent_requests=concurrent_requests, + concurrent_requests=effective_concurrent, json_output=json_output, ) - after_length = len(col) + _after = len(col) else: if not auto_tenant_creation_enabled and not col.tenants.exists(tenant): raise Exception( @@ -894,12 +913,12 @@ def create_data( f"Tenant '{tenant}' is not active. Please activate it using command" ) if auto_tenant_creation_enabled and not col.tenants.exists(tenant): - initial_length = 0 + _initial = 0 else: - initial_length = len(col.with_tenant(tenant)) - if not json_output: + _initial = len(col.with_tenant(tenant)) + if not json_output and not _parallel_mode: click.echo(f"Processing objects for tenant '{tenant}'") - collection = self.__ingest_data( + _coll = self.__ingest_data( collection=col.with_tenant(tenant), num_objects=limit, cl=cl_map[consistency_level], @@ -911,18 +930,48 @@ def create_data( multi_vector=multi_vector, dynamic_batch=dynamic_batch, batch_size=batch_size, - concurrent_requests=concurrent_requests, + concurrent_requests=effective_concurrent, json_output=json_output, ) - after_length = len(col.with_tenant(tenant)) + _after = len(col.with_tenant(tenant)) if wait_for_indexing: - collection.batch.wait_for_vector_indexing() - inserted = after_length - initial_length - total_inserted += inserted - if inserted != limit: - click.echo( - f"Error occurred while ingesting data for tenant '{tenant}'. Expected number of objects inserted: {limit}. Actual number of objects inserted: {inserted}. Double check with weaviate-cli get collection" + _coll.batch.wait_for_vector_indexing() + _inserted = _after - _initial + if _inserted != limit: + with _output_lock: + click.echo( + f"Error occurred while ingesting data for tenant '{tenant}'. " + f"Expected number of objects inserted: {limit}. " + f"Actual number of objects inserted: {_inserted}. " + f"Double check with weaviate-cli get collection" + ) + return _inserted, _coll + + collection = col + if _parallel_mode: + _lock = threading.Lock() + _errors: List[str] = [] + with ThreadPoolExecutor(max_workers=actual_workers) as executor: + future_to_tenant = { + executor.submit(_ingest_one_tenant, t): t for t in tenants + } + for future in as_completed(future_to_tenant): + t = future_to_tenant[future] + try: + inserted, _coll = future.result() + with _lock: + total_inserted += inserted + except Exception as exc: + _errors.append(f"Tenant '{t}': {exc}") + if _errors: + raise Exception( + "Errors during parallel data ingestion:\n" + "\n".join(_errors) ) + else: + for tenant in tenants: + inserted, collection = _ingest_one_tenant(tenant) + total_inserted += inserted + if json_output: click.echo( json.dumps( @@ -1123,6 +1172,7 @@ def update_data( randomize: bool = UpdateDataDefaults.randomize, skip_seed: bool = UpdateDataDefaults.skip_seed, verbose: bool = UpdateDataDefaults.verbose, + parallel_workers: int = UpdateDataDefaults.parallel_workers, json_output: bool = False, ) -> None: @@ -1153,9 +1203,10 @@ def update_data( if not json_output: click.echo(f"Preparing to update {limit} objects into class '{col.name}'") total_updated = 0 - for tenant in tenants: + + def _update_one_tenant(tenant: str) -> int: if tenant == "None": - ret = self.__update_data( + return self.__update_data( col, limit, cl_map[consistency_level], @@ -1164,23 +1215,58 @@ def update_data( verbose, json_output=json_output, ) - else: - if not json_output: - click.echo(f"Processing tenant '{tenant}'") - ret = self.__update_data( - col.with_tenant(tenant), - limit, - cl_map[consistency_level], - randomize, - skip_seed, - verbose, - json_output=json_output, + if not json_output and parallel_workers <= 1: + click.echo(f"Processing tenant '{tenant}'") + return self.__update_data( + col.with_tenant(tenant), + limit, + cl_map[consistency_level], + randomize, + skip_seed, + verbose, + json_output=json_output, + ) + + if len(tenants) > 1 and parallel_workers > 1: + actual_workers = min(parallel_workers, len(tenants)) + _lock = threading.Lock() + _errors: List[str] = [] + with ThreadPoolExecutor(max_workers=actual_workers) as executor: + future_to_tenant = { + executor.submit(_update_one_tenant, t): t for t in tenants + } + for future in as_completed(future_to_tenant): + t = future_to_tenant[future] + try: + ret = future.result() + if ret == -1: + _errors.append( + f"Failed to update objects in class '{col.name}' for tenant '{t}'" + ) + else: + with _lock: + total_updated += ret + except Exception as exc: + _errors.append(f"Tenant '{t}': {exc}") + if _errors: + raise Exception( + "Errors during parallel data update:\n" + "\n".join(_errors) ) - if ret == -1: + else: + _errors: List[str] = [] + for tenant in tenants: + ret = _update_one_tenant(tenant) + if ret == -1: + _errors.append( + f"Failed to update objects in class '{col.name}' for tenant '{tenant}'" + ) + else: + total_updated += ret + if _errors: raise Exception( - f"Failed to update objects in class '{col.name}' for tenant '{tenant}'" + "Errors during sequential data update:\n" + "\n".join(_errors) ) - total_updated += ret + if json_output: click.echo( json.dumps( @@ -1279,6 +1365,7 @@ def delete_data( tenants_list: Optional[List[str]] = None, uuid: Optional[str] = DeleteDataDefaults.uuid, verbose: bool = DeleteDataDefaults.verbose, + parallel_workers: int = DeleteDataDefaults.parallel_workers, json_output: bool = False, ) -> None: @@ -1306,27 +1393,45 @@ def delete_data( total_deleted = 0 - for tenant in tenants: + def _delete_one_tenant(tenant: str) -> int: if tenant == "None": - ret = self.__delete_data( # NOTE: call the correct delete impl + return self.__delete_data( col, limit, cl_map[consistency_level], uuid, verbose, json_output ) - else: - if not json_output: - click.echo(f"Processing tenant '{tenant}'") - ret = self.__delete_data( - col.with_tenant(tenant), - limit, - cl_map[consistency_level], - uuid, - verbose, - json_output, - ) - if ret == -1: + if not json_output and parallel_workers <= 1: + click.echo(f"Processing tenant '{tenant}'") + return self.__delete_data( + col.with_tenant(tenant), + limit, + cl_map[consistency_level], + uuid, + verbose, + json_output, + ) + + if len(tenants) > 1 and parallel_workers > 1: + actual_workers = min(parallel_workers, len(tenants)) + _lock = threading.Lock() + _errors: List[str] = [] + with ThreadPoolExecutor(max_workers=actual_workers) as executor: + future_to_tenant = { + executor.submit(_delete_one_tenant, t): t for t in tenants + } + for future in as_completed(future_to_tenant): + t = future_to_tenant[future] + try: + with _lock: + total_deleted += future.result() + except Exception as exc: + _errors.append(f"Tenant '{t}': {exc}") + if _errors: raise Exception( - f"Failed to delete objects in class '{col.name}' for tenant '{tenant}'" + "Errors during parallel data deletion:\n" + "\n".join(_errors) ) - total_deleted += ret + else: + for tenant in tenants: + total_deleted += _delete_one_tenant(tenant) + if json_output: click.echo( json.dumps( diff --git a/weaviate_cli/managers/tenant_manager.py b/weaviate_cli/managers/tenant_manager.py index 2c70331..724dce0 100644 --- a/weaviate_cli/managers/tenant_manager.py +++ b/weaviate_cli/managers/tenant_manager.py @@ -266,9 +266,8 @@ def delete_tenants( if not deleting_tenants: raise Exception(f"No tenants present in class {collection.name}.") - - for tenant in deleting_tenants.values(): - collection.tenants.remove(tenant) + else: + collection.tenants.remove(list(deleting_tenants.values())) except Exception as e: raise Exception(f"Failed to delete tenants: {e}") @@ -472,12 +471,19 @@ def update_tenants( else: existing_tenants = dict(list(tenants_with_suffix.items())[:number_tenants]) - for name, tenant in existing_tenants.items(): - collection.tenants.update( + tenants_to_update = [ + ( Tenant(name=name, activity_status=tenant_state_map[state]) if tenant.activity_status != tenant_state_map[state] else tenant ) + for name, tenant in existing_tenants.items() + ] + if not tenants_to_update: + raise Exception( + f"No matching tenants found in collection '{collection.name}' to update." + ) + collection.tenants.update(tenants_to_update) # get_by_names is only available after 1.25.0 if version.compare(semver.Version.parse("1.25.0")) < 0: