From 0dff9fd4150b6e5880785e5be2fdfcb99ef1756b Mon Sep 17 00:00:00 2001 From: nnkhanhduy Date: Fri, 13 Mar 2026 16:43:34 +0700 Subject: [PATCH 1/4] Add support for multiple Layer 2 devices --- other/DCSL/src/RpcClient.py | 3 ++- other/DCSL/src/Scheduler.py | 42 +++++++++++++++++-------------------- other/DCSL/src/Server.py | 8 +++++-- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/other/DCSL/src/RpcClient.py b/other/DCSL/src/RpcClient.py index fde4e5a..b4aea60 100644 --- a/other/DCSL/src/RpcClient.py +++ b/other/DCSL/src/RpcClient.py @@ -99,6 +99,7 @@ def response_message(self, body): lr = self.response["lr"] momentum = self.response["momentum"] sda_size = self.response.get("sda_size", 1) + layer2_devices = self.response.get("layer2_devices", []) # Read parameters and load to model if state_dict: @@ -113,7 +114,7 @@ def response_message(self, body): subset = torch.utils.data.Subset(self.train_set, selected_indices) train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) - result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round) + result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, layer2_devices=layer2_devices) else: result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size) diff --git a/other/DCSL/src/Scheduler.py b/other/DCSL/src/Scheduler.py index 6a85bd4..6dce01b 100644 --- a/other/DCSL/src/Scheduler.py +++ b/other/DCSL/src/Scheduler.py @@ -18,9 +18,12 @@ def __init__(self, client_id, layer_id, channel, device): self.device = device self.data_count = 0 - def send_intermediate_output(self, output, labels, trace, data_id=None): + def send_intermediate_output(self, output, labels, trace, data_id=None, target_device_id=None): - forward_queue_name = f'intermediate_queue_{self.layer_id}' + if target_device_id is not None: + forward_queue_name = f'intermediate_queue_{target_device_id}' + else: + forward_queue_name = f'intermediate_queue_{self.layer_id}' self.channel.queue_declare(forward_queue_name, durable=False) @@ -63,11 +66,7 @@ def send_to_server(self, message): routing_key='rpc_queue', body=pickle.dumps(message)) - def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=3): - """ - Synchronous training: forward 1 batch → wait for gradient → backward → next batch. - Edge device does NOT send multiple batches before receiving gradient. 
- """ + def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=3, layer2_devices=None): optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' @@ -76,6 +75,8 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou model.to(self.device) + batch_counter = 0 + for i in range(local_round): src.Log.print_with_color(f'Epoch {i}', 'green') @@ -91,8 +92,13 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou self.data_count += 1 pbar.update(1) - # Step 2: Send smashed data to server - self.send_intermediate_output(intermediate_output, labels, trace=None, data_id=data_id) + # Step 2: Send smashed data to target layer-2 device (round-robin) + target_device_id = None + if layer2_devices: + target_device_id = layer2_devices[batch_counter % len(layer2_devices)] + batch_counter += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None, data_id=data_id, target_device_id=target_device_id) # Step 3: Wait for gradient (blocking) while True: @@ -129,11 +135,6 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou time.sleep(0.5) def _process_sda_batch(self, model, optimizer, criterion, collected): - """ - SDA (Smashed Data Aggregation) — Eq. 4-5 from paper. - Concatenate smashed data from all clients, forward once, - split gradient back to each client. - """ batch_sizes = [item["data"].shape[0] for item in collected] traces = [item["trace"] for item in collected] data_ids = [item["data_id"] for item in collected] @@ -174,19 +175,14 @@ def _process_sda_batch(self, model, optimizer, criterion, collected): return result def train_on_last_layer(self, model, lr, momentum, sda_size=1): - """ - SDA: collect exactly 1 batch from each client, - concat and forward once, split gradient back. - Since edge devices are synchronous, no overflow needed. - """ optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) result = True criterion = nn.CrossEntropyLoss() - forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' + forward_queue_name = f'intermediate_queue_{self.client_id}' self.channel.queue_declare(queue=forward_queue_name, durable=False) self.channel.basic_qos(prefetch_count=1) - print(f'Waiting for intermediate output (SDA size={sda_size}). To exit press CTRL+C') + print(f'Waiting for intermediate output on queue {forward_queue_name} (SDA size={sda_size}). 
To exit press CTRL+C') model.to(self.device) sda_batch = {} # {client_id: data} — exactly 1 batch per client @@ -219,10 +215,10 @@ def train_on_last_layer(self, model, lr, momentum, sda_size=1): result = False return result - def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, sda_size=1): + def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, sda_size=1, layer2_devices=None): self.data_count = 0 if self.layer_id == 1: - result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round) + result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, layer2_devices=layer2_devices) else: result = self.train_on_last_layer(model, lr, momentum, sda_size) diff --git a/other/DCSL/src/Server.py b/other/DCSL/src/Server.py index 3468a27..3a08146 100644 --- a/other/DCSL/src/Server.py +++ b/other/DCSL/src/Server.py @@ -244,6 +244,8 @@ def on_request(self, ch, method, props, body): def notify_clients(self, start=True, register=True, idx=0, avg_model=None): if start: + layer2_device_ids = [str(cid) for (cid, lid, cl) in self.list_clients if lid == 2] + if register: klass = globals()[f'{self.model_name}_{self.data_name}'] @@ -298,7 +300,8 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): "momentum": self.momentum, "label_count": label, "local_round": self.local_round, - "sda_size": self.sda_size + "sda_size": self.sda_size, + "layer2_devices": layer2_device_ids } self.send_to_response(client_id, pickle.dumps(response)) @@ -320,7 +323,8 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): "momentum": self.momentum, "label_count": label, "local_round": self.local_round, - "sda_size": self.sda_size + "sda_size": self.sda_size, + "layer2_devices": layer2_device_ids } self.send_to_response(client_id, pickle.dumps(response)) From 7d98b9e7b3ce4bc1569c3d24a0c75234b9b5413d Mon Sep 17 00:00:00 2001 From: nnkhanhduy Date: Mon, 16 Mar 2026 17:43:14 +0700 Subject: [PATCH 2/4] Update Bert and kwt --- other/DCSL/src/RpcClient.py | 32 ++- other/DCSL/src/Server.py | 31 ++- other/DCSL/src/Validation.py | 11 ++ other/DCSL/src/dataset/EMOTION.py | 73 +++++++ other/DCSL/src/dataset/SPEECHCOMMANDS.py | 215 +++++++++++++++++++++ other/DCSL/src/model/KWT_SPEECHCOMMANDS.py | 125 ++++++++++++ other/DCSL/src/model/__init__.py | 2 + 7 files changed, 477 insertions(+), 12 deletions(-) create mode 100644 other/DCSL/src/dataset/EMOTION.py create mode 100644 other/DCSL/src/dataset/SPEECHCOMMANDS.py create mode 100644 other/DCSL/src/model/KWT_SPEECHCOMMANDS.py diff --git a/other/DCSL/src/RpcClient.py b/other/DCSL/src/RpcClient.py index b4aea60..ce7aa70 100644 --- a/other/DCSL/src/RpcClient.py +++ b/other/DCSL/src/RpcClient.py @@ -75,13 +75,41 @@ def response_message(self, body): ]) self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + elif data_name == "SPEECHCOMMANDS": + from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + self.train_set = SpeechCommandsDataset(root='./data', subset='training') + elif data_name == "EMOTION": + from datasets import load_dataset + from transformers import BertTokenizer + from src.dataset.EMOTION import EMOTIONDataset + + dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + train_data = dataset['train'] + texts = train_data['text'] + labels = train_data['label'] + + 
self.train_set = EMOTIONDataset(texts, labels, tokenizer, max_length=128) else: self.train_set = None raise ValueError(f"Data name '{data_name}' is not valid.") self.label_to_indices = defaultdict(list) - for idx, (_, label) in tqdm(enumerate(self.train_set)): - self.label_to_indices[int(label)].append(idx) + if data_name == "EMOTION": + for idx, label in enumerate(self.train_set.labels): + self.label_to_indices[int(label)].append(idx) + elif data_name == "SPEECHCOMMANDS": + from src.dataset.SPEECHCOMMANDS import CLASSES + for idx, (audio_path, label_name) in enumerate(self.train_set.samples): + if label_name in CLASSES: + label_idx = CLASSES.index(label_name) + else: + label_idx = CLASSES.index('unknown') + self.label_to_indices[label_idx].append(idx) + else: + for idx, (_, label) in tqdm(enumerate(self.train_set)): + self.label_to_indices[int(label)].append(idx) # Load model if self.model is None: diff --git a/other/DCSL/src/Server.py b/other/DCSL/src/Server.py index 3a08146..3c36c94 100644 --- a/other/DCSL/src/Server.py +++ b/other/DCSL/src/Server.py @@ -89,16 +89,27 @@ def distribution(self): if self.non_iid: # label_distribution = np.random.dirichlet([self.data_distribution["dirichlet"]["alpha"]] * self.num_label, # self.total_clients[0]) - - label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + #VGG16 + # label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + # [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + # [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + # ]) + #KWT + label_distribution = np.array([[0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], ]) self.label_counts = (label_distribution * self.num_sample).astype(int) diff --git a/other/DCSL/src/Validation.py b/other/DCSL/src/Validation.py index eb7dd7a..a5693f1 100644 --- a/other/DCSL/src/Validation.py +++ b/other/DCSL/src/Validation.py @@ -24,6 +24,17 @@ def test(model_name, data_name, state_dict_full, logger): transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + elif data_name == "SPEECHCOMMANDS": + from 
src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + testset = SpeechCommandsDataset(root='./data', subset='testing') + elif data_name == "EMOTION": + from datasets import load_dataset + from transformers import BertTokenizer + from src.dataset.EMOTION import EMOTIONDataset, load_test_EMOTION + dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + test_texts, test_labels = load_test_EMOTION(1000, dataset) + testset = EMOTIONDataset(test_texts, test_labels, tokenizer, max_length=128) else: raise ValueError(f"Data name '{data_name}' is not valid.") diff --git a/other/DCSL/src/dataset/EMOTION.py b/other/DCSL/src/dataset/EMOTION.py new file mode 100644 index 0000000..b24b487 --- /dev/null +++ b/other/DCSL/src/dataset/EMOTION.py @@ -0,0 +1,73 @@ +import torch +import random +from collections import defaultdict + +class EMOTIONDataset(torch.utils.data.Dataset): + def __init__(self, texts, labels, tokenizer, max_length=128): + self.texts = texts + self.labels = labels + self.tokenizer = tokenizer + self.max_length = max_length + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = str(self.texts[idx]) + label = self.labels[idx] + + # Tokenize + encoding = self.tokenizer( + text, + truncation=True, + padding='max_length', + max_length=self.max_length, + return_tensors='pt' + ) + + return encoding['input_ids'].flatten(), torch.tensor(label, dtype=torch.long) + +def load_train_EMOTION(dataset=None, distribution=None): + random.seed(1) + + train_data = dataset['train'] + + train_target_counts = {k: v for k, v in enumerate(distribution)} + + train_by_class = defaultdict(list) + for text, label in zip(train_data['text'], train_data['label']): + train_by_class[label].append((text, label)) + + train_texts, train_labels = [], [] + for label, count in train_target_counts.items(): + samples = random.sample(train_by_class[label], count) + train_texts.extend([t for t, _ in samples]) + train_labels.extend([l for _, l in samples]) + + print("Train samples:", len(train_texts), {l: train_labels.count(l) for l in set(train_labels)}) + + return train_texts, train_labels + +def load_test_EMOTION(test_total=1000, dataset=None): + random.seed(1) + test_data = dataset['test'] + class_distribution = { + 0: 0.25, + 1: 0.25, + 2: 0.25, + 3: 0.25 + } + test_target_counts = {k: int(v * test_total) for k, v in class_distribution.items()} + test_by_class = defaultdict(list) + for text, label in zip(test_data['text'], test_data['label']): + test_by_class[label].append((text, label)) + + test_texts, test_labels = [], [] + for label, count in test_target_counts.items(): + samples = random.sample(test_by_class[label], count) + test_texts.extend([t for t, _ in samples]) + test_labels.extend([l for _, l in samples]) + + print("Test samples:", len(test_texts), {l: test_labels.count(l) for l in set(test_labels)}) + + return test_texts, test_labels \ No newline at end of file diff --git a/other/DCSL/src/dataset/SPEECHCOMMANDS.py b/other/DCSL/src/dataset/SPEECHCOMMANDS.py new file mode 100644 index 0000000..b50f27e --- /dev/null +++ b/other/DCSL/src/dataset/SPEECHCOMMANDS.py @@ -0,0 +1,215 @@ +import os +import random +import torch +import numpy as np +from torch.utils.data import Dataset +from scipy.io import wavfile +from scipy.fftpack import dct + +# 12 classes standard (10 keywords + silence + unknown) +CLASSES = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 
'silence', 'unknown'] + +def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=160, n_mels=40): + """Compute MFCC features using scipy (no torchaudio backend needed)""" + # Pre-emphasis + emphasized = np.append(waveform[0], waveform[1:] - 0.97 * waveform[:-1]) + + # Framing + frame_length = n_fft + num_frames = 1 + (len(emphasized) - frame_length) // hop_length + frames = np.zeros((num_frames, frame_length)) + for i in range(num_frames): + frames[i] = emphasized[i * hop_length : i * hop_length + frame_length] + + # Windowing + frames *= np.hamming(frame_length) + + # FFT + mag_frames = np.absolute(np.fft.rfft(frames, n_fft)) + pow_frames = (1.0 / n_fft) * (mag_frames ** 2) + + # Mel filterbank + low_freq_mel = 0 + high_freq_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700) + mel_points = np.linspace(low_freq_mel, high_freq_mel, n_mels + 2) + hz_points = 700 * (10 ** (mel_points / 2595) - 1) + bin_points = np.floor((n_fft + 1) * hz_points / sample_rate).astype(int) + + fbank = np.zeros((n_mels, int(np.floor(n_fft / 2 + 1)))) + for m in range(1, n_mels + 1): + f_m_minus = bin_points[m - 1] + f_m = bin_points[m] + f_m_plus = bin_points[m + 1] + for k in range(f_m_minus, f_m): + fbank[m - 1, k] = (k - f_m_minus) / (f_m - f_m_minus) + for k in range(f_m, f_m_plus): + fbank[m - 1, k] = (f_m_plus - k) / (f_m_plus - f_m) + + filter_banks = np.dot(pow_frames, fbank.T) + filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks) + filter_banks = 20 * np.log10(filter_banks) + + # DCT to get MFCCs + mfcc = dct(filter_banks, type=2, axis=1, norm='ortho')[:, :n_mfcc] + + return mfcc.T # (n_mfcc, time_steps) + + +def _load_background_noise(root): + """Load all background noise files from _background_noise_ directory""" + noise_dir = os.path.join(root, '_background_noise_') + noise_data = [] + if not os.path.exists(noise_dir): + print(f"[WARNING] Background noise dir not found: {noise_dir}") + return noise_data + + for f in os.listdir(noise_dir): + if not f.endswith('.wav'): + continue + fpath = os.path.join(noise_dir, f) + try: + sr, wav = wavfile.read(fpath) + if wav.dtype == np.int16: + wav = wav.astype(np.float32) / 32768.0 + elif wav.dtype == np.int32: + wav = wav.astype(np.float32) / 2147483648.0 + else: + wav = wav.astype(np.float32) + noise_data.append(wav) + except Exception as e: + print(f"[WARNING] Error reading noise file {fpath}: {e}") + + print(f"Loaded {len(noise_data)} background noise files") + return noise_data + + +class SpeechCommandsDataset(Dataset): + """Google Speech Commands V2 dataset with MFCC features for KWT model""" + + def __init__(self, root='./data', subset='training', n_silence=2300): + """ + Args: + root: Data directory + subset: 'training', 'validation', or 'testing' + n_silence: Number of silence samples to generate (training). + For val/test, uses n_silence // 9 (~260). + """ + self.root = os.path.join(root, 'SpeechCommands', 'speech_commands_v0.02') + self.subset = subset + self.samples = [] # list of (path_or_None, label_str) + self.noise_data = [] # background noise waveforms for silence + + # Check if dataset exists + if not os.path.exists(self.root): + raise RuntimeError(f"Dataset not found at {self.root}. 
Please download manually.") + + # Load background noise for silence class + self.noise_data = _load_background_noise(self.root) + + # Load validation/testing list + val_list = set() + test_list = set() + + val_file = os.path.join(self.root, 'validation_list.txt') + test_file = os.path.join(self.root, 'testing_list.txt') + + if os.path.exists(val_file): + with open(val_file) as f: + val_list = set(line.strip() for line in f) + if os.path.exists(test_file): + with open(test_file) as f: + test_list = set(line.strip() for line in f) + + # Collect keyword & unknown samples based on subset + for label_dir in os.listdir(self.root): + label_path = os.path.join(self.root, label_dir) + if not os.path.isdir(label_path) or label_dir.startswith('_'): + continue + + for audio_file in os.listdir(label_path): + if not audio_file.endswith('.wav'): + continue + + rel_path = os.path.join(label_dir, audio_file) + + if subset == 'validation' and rel_path in val_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'testing' and rel_path in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'training' and rel_path not in val_list and rel_path not in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + + # Generate silence samples (random 1s crops from background noise) + # Following the paper: silence class created from _background_noise_ + if self.noise_data: + if subset == 'training': + num_silence = n_silence + else: + num_silence = max(1, n_silence // 9) # ~260 for val/test + + for _ in range(num_silence): + # path=None marks this as a silence sample (generated on-the-fly) + self.samples.append((None, 'silence')) + + print(f"Added {num_silence} silence samples from background noise") + + print(f"Total {len(self.samples)} samples for {subset}") + + def _get_silence_waveform(self): + """Generate a 1-second silence waveform by random crop from background noise""" + target_length = 16000 + noise = random.choice(self.noise_data) + + if len(noise) <= target_length: + # Noise file too short, pad with zeros + waveform = np.pad(noise, (0, max(0, target_length - len(noise)))) + else: + start = random.randint(0, len(noise) - target_length) + waveform = noise[start : start + target_length] + + return waveform + + def __getitem__(self, idx): + audio_path, label = self.samples[idx] + + try: + if audio_path is None: + # Silence sample: random crop from background noise + waveform = self._get_silence_waveform() + else: + # Load audio using scipy + sample_rate, waveform = wavfile.read(audio_path) + + # Convert to float and normalize + if waveform.dtype == np.int16: + waveform = waveform.astype(np.float32) / 32768.0 + elif waveform.dtype == np.int32: + waveform = waveform.astype(np.float32) / 2147483648.0 + else: + waveform = waveform.astype(np.float32) + + # Pad/trim to 1 second (16000 samples) + target_length = 16000 + if len(waveform) < target_length: + waveform = np.pad(waveform, (0, target_length - len(waveform))) + else: + waveform = waveform[:target_length] + + # Compute MFCC features + mfcc = compute_mfcc(waveform, sample_rate=16000, n_mfcc=40) + mfcc = torch.tensor(mfcc, dtype=torch.float32) + + except Exception as e: + print(f"[WARNING] Error processing sample {idx}: {e}, using zeros") + mfcc = torch.zeros(40, 98, dtype=torch.float32) + + # Map label to class index + if label in CLASSES: + label_idx = CLASSES.index(label) + else: + label_idx = CLASSES.index('unknown') + + return mfcc, 
label_idx + + def __len__(self): + return len(self.samples) diff --git a/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py b/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py new file mode 100644 index 0000000..4a53f33 --- /dev/null +++ b/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=1, mlp_dim=256): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + # Attention + residual + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + x_attn + # MLP + residual + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class KWT_SPEECHCOMMANDS(nn.Module): + """ + Keyword Transformer (KWT-1) for Speech Commands - Split Learning version + KWT-1: dim=64, mlp_dim=256, heads=1, layers=12 + + Layers: + 1: Linear embedding (n_mfcc -> embed_dim) + 2: CLS token concatenation + 3: Positional embedding + Dropout + 4-15: 12x Transformer encoder blocks + 16: LayerNorm + 17: Classification head + """ + + def __init__(self, start_layer=0, end_layer=17): + super().__init__() + self.start_layer = start_layer + # Handle -1 as "all remaining layers" + self.end_layer = 17 if end_layer == -1 else end_layer + + # KWT-1 config (matches Colab) + n_mfcc = 40 + time_steps = 98 # (16000 - 480) / 160 + 1 ≈ 98 + embed_dim = 64 + num_heads = 1 + mlp_dim = 256 + num_classes = 12 + dropout = 0.1 + + # Layer 1: Linear embedding + if self.start_layer < 1 <= self.end_layer: + self.layer1 = nn.Linear(n_mfcc, embed_dim) + + # Layer 2: CLS token + if self.start_layer < 2 <= self.end_layer: + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + # Layer 3: Positional embedding + Dropout + if self.start_layer < 3 <= self.end_layer: + self.pos_embed = nn.Parameter(torch.randn(1, time_steps + 1, embed_dim)) + self.dropout = nn.Dropout(dropout) + + # Layers 4-15: 12x Transformer encoder blocks + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + TransformerEncoderBlock(embed_dim, num_heads, mlp_dim)) + + # Layer 16: LayerNorm + if self.start_layer < 16 <= self.end_layer: + self.layer16 = nn.LayerNorm(embed_dim) + + # Layer 17: Classification head + if self.start_layer < 17 <= self.end_layer: + self.layer17 = nn.Linear(embed_dim, num_classes) + + self._init_weights() + + def _init_weights(self): + if hasattr(self, 'cls_token'): + nn.init.trunc_normal_(self.cls_token, std=0.02) + if hasattr(self, 'pos_embed'): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + # Input x shape: (batch, n_mfcc, time_steps) + + # Layer 1: Linear embedding + if self.start_layer < 1 <= self.end_layer: + x = x.transpose(1, 2) # (batch, time_steps, n_mfcc) + x = self.layer1(x) # (batch, time_steps, embed_dim) + + # Layer 2: CLS token + if self.start_layer < 2 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) # (batch, time_steps+1, embed_dim) + + # Layer 3: Positional embedding + Dropout + if self.start_layer < 3 <= self.end_layer: + x = x + self.pos_embed + x = self.dropout(x) + + # Layers 4-15: 12x Transformer blocks + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= 
self.end_layer: + x = getattr(self, f'layer{layer_idx}')(x) + + # Layer 16: LayerNorm on CLS token + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x[:, 0]) + + # Layer 17: Classification + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + + return x diff --git a/other/DCSL/src/model/__init__.py b/other/DCSL/src/model/__init__.py index b0304dd..8353b7e 100644 --- a/other/DCSL/src/model/__init__.py +++ b/other/DCSL/src/model/__init__.py @@ -4,3 +4,5 @@ from .VGG16_MNIST import * from .ViT_CIFAR10 import * from .ViT_MNIST import * +from .BERT_EMOTION import * +from .KWT_SPEECHCOMMANDS import * From 03b44f9a3b562529381ed9f11bbca26b365e9d5f Mon Sep 17 00:00:00 2001 From: nnkhanhduy Date: Sat, 21 Mar 2026 10:26:12 +0700 Subject: [PATCH 3/4] feat: Add Speech Commands and AG News datasets with their respective KWT and BERT models, alongside RPC client/server, validation, and scheduler modules. --- other/DCSL/src/RpcClient.py | 37 +- other/DCSL/src/Scheduler.py | 61 ++- other/DCSL/src/Server.py | 30 +- other/DCSL/src/Validation.py | 37 +- .../src/dataset/{EMOTION.py => AGNEWS.py} | 48 +- other/DCSL/src/dataset/SPEECHCOMMANDS.py | 8 - other/DCSL/src/model/BERT_AGNEWS.py | 221 +++++++++ other/DCSL/src/model/BERT_EMOTION.py | 428 ------------------ other/DCSL/src/model/KWT_SPEECHCOMMANDS.py | 13 +- other/DCSL/src/model/__init__.py | 2 +- 10 files changed, 361 insertions(+), 524 deletions(-) rename other/DCSL/src/dataset/{EMOTION.py => AGNEWS.py} (51%) create mode 100644 other/DCSL/src/model/BERT_AGNEWS.py delete mode 100644 other/DCSL/src/model/BERT_EMOTION.py diff --git a/other/DCSL/src/RpcClient.py b/other/DCSL/src/RpcClient.py index ce7aa70..37b4a5d 100644 --- a/other/DCSL/src/RpcClient.py +++ b/other/DCSL/src/RpcClient.py @@ -11,6 +11,8 @@ import src.Log from src.model import * +from peft import LoraConfig, get_peft_model + class RpcClient: def __init__(self, client_id, layer_id, channel, train_func, device): @@ -23,6 +25,7 @@ def __init__(self, client_id, layer_id, channel, train_func, device): self.response = None self.model = None self.label_count = None + self.peft_config = None self.train_set = None self.label_to_indices = None @@ -78,25 +81,25 @@ def response_message(self, body): elif data_name == "SPEECHCOMMANDS": from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset self.train_set = SpeechCommandsDataset(root='./data', subset='training') - elif data_name == "EMOTION": + elif data_name == "AGNEWS": from datasets import load_dataset from transformers import BertTokenizer - from src.dataset.EMOTION import EMOTIONDataset + from src.dataset.AGNEWS import AGNEWS_DATASET dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') train_data = dataset['train'] texts = train_data['text'] labels = train_data['label'] - self.train_set = EMOTIONDataset(texts, labels, tokenizer, max_length=128) + self.train_set = AGNEWS_DATASET(texts, labels, tokenizer, max_length=128) else: self.train_set = None raise ValueError(f"Data name '{data_name}' is not valid.") self.label_to_indices = defaultdict(list) - if data_name == "EMOTION": + if data_name == "AGNEWS": for idx, label in enumerate(self.train_set.labels): self.label_to_indices[int(label)].append(idx) elif data_name == "SPEECHCOMMANDS": @@ -133,6 +136,22 @@ def response_message(self, body): if state_dict: self.model.load_state_dict(state_dict) + # 
Apply LoRA for BERT model + if model_name == 'BERT': + if self.peft_config is None: + self.peft_config = LoraConfig( + task_type="SEQ_CLS", + r=8, lora_alpha=16, lora_dropout=0.1, + bias="none", + target_modules=["query", "key", "value", "dense"] + ) + self.model = get_peft_model(self.model, self.peft_config) + if self.layer_id == 2: + for param in self.model.layer15.parameters(): + param.requires_grad = True + + self.model.to(self.device) + # Start training if self.layer_id == 1: selected_indices = [] @@ -142,10 +161,14 @@ def response_message(self, body): subset = torch.utils.data.Subset(self.train_set, selected_indices) train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) - result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, layer2_devices=layer2_devices) + result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, layer2_devices=layer2_devices, model_name=model_name) else: - result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size) + result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size, model_name=model_name) + + # Merge LoRA weights back for BERT + if model_name == 'BERT': + self.model = self.model.merge_and_unload() # Stop training, then send parameters to server model_state_dict = copy.deepcopy(self.model.state_dict()) diff --git a/other/DCSL/src/Scheduler.py b/other/DCSL/src/Scheduler.py index 6dce01b..60d3bb4 100644 --- a/other/DCSL/src/Scheduler.py +++ b/other/DCSL/src/Scheduler.py @@ -66,8 +66,11 @@ def send_to_server(self, message): routing_key='rpc_queue', body=pickle.dumps(message)) - def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=3, layer2_devices=None): - optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=3, layer2_devices=None, model_name=None): + if model_name == 'BERT': + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + else: + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' self.channel.queue_declare(queue=backward_queue_name, durable=False) @@ -81,18 +84,29 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou src.Log.print_with_color(f'Epoch {i}', 'green') with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: - for training_data, labels in train_loader: - training_data = training_data.to(self.device) + for batch in train_loader: + if isinstance(batch, dict) and 'input_ids' in batch: + training_data = batch['input_ids'].to(self.device) + attention_mask = batch['attention_mask'].to(self.device) + labels = batch['labels'].to(self.device) + kwargs = {'input_ids': training_data, 'attention_mask': attention_mask} + else: + training_data, labels = batch + training_data = training_data.to(self.device) + labels = labels.to(self.device) + kwargs = {} - # Step 1: Forward data_id = str(uuid.uuid4()) - intermediate_output = model(training_data) + with torch.no_grad(): + if 'input_ids' in kwargs: + intermediate_output = model(**kwargs) + else: + intermediate_output = model(training_data, **kwargs) intermediate_output = intermediate_output.detach().requires_grad_(True) self.data_count += 1 pbar.update(1) - # Step 2: Send smashed data to target layer-2 device 
(round-robin) target_device_id = None if layer2_devices: target_device_id = layer2_devices[batch_counter % len(layer2_devices)] @@ -100,7 +114,6 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou self.send_intermediate_output(intermediate_output, labels, trace=None, data_id=data_id, target_device_id=target_device_id) - # Step 3: Wait for gradient (blocking) while True: method_frame, header_frame, body = self.channel.basic_get( queue=backward_queue_name, auto_ack=True) @@ -108,10 +121,12 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou received_data = pickle.loads(body) gradient = torch.tensor(received_data["data"]).to(self.device) - # Step 4: Backward model.train() optimizer.zero_grad() - output = model(training_data) + if 'input_ids' in kwargs: + output = model(**kwargs) + else: + output = model(training_data, **kwargs) output.backward(gradient=gradient) optimizer.step() break @@ -134,12 +149,11 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou return True time.sleep(0.5) - def _process_sda_batch(self, model, optimizer, criterion, collected): + def _process_sda_batch(self, model, optimizer, criterion, collected, model_name=None): batch_sizes = [item["data"].shape[0] for item in collected] traces = [item["trace"] for item in collected] data_ids = [item["data_id"] for item in collected] - # Eq. 4: S_c = concat(σ_1, σ_2, ..., σ_|D_c|) all_data = np.concatenate([item["data"] for item in collected], axis=0) all_labels = np.concatenate([item["label"] for item in collected], axis=0) @@ -150,8 +164,10 @@ def _process_sda_batch(self, model, optimizer, criterion, collected): optimizer.zero_grad() concat_intermediate.retain_grad() - # Eq. 5: ŷ = f(S_c | W) - output = model(concat_intermediate) + if model_name == 'BERT': + output = model(input_ids=concat_intermediate) + else: + output = model(concat_intermediate) loss = criterion(output, concat_labels.long()) print(f"Loss (SDA, {len(collected)} clients, {sum(batch_sizes)} samples): {loss.item():.4f}") @@ -174,8 +190,11 @@ def _process_sda_batch(self, model, optimizer, criterion, collected): return result - def train_on_last_layer(self, model, lr, momentum, sda_size=1): - optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + def train_on_last_layer(self, model, lr, momentum, sda_size=1, model_name=None): + if model_name == 'BERT': + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + else: + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) result = True criterion = nn.CrossEntropyLoss() @@ -196,7 +215,7 @@ def train_on_last_layer(self, model, lr, momentum, sda_size=1): # When we have 1 batch from each client → SDA forward if len(sda_batch) >= sda_size: - batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values())) + batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values()), model_name=model_name) if not batch_result: result = False sda_batch = {} @@ -210,16 +229,16 @@ def train_on_last_layer(self, model, lr, momentum, sda_size=1): if received_data["action"] == "PAUSE": # Process remaining if sda_batch: - batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values())) + batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values()), model_name=model_name) if not batch_result: result = False return result - def train_on_device(self, model, lr, momentum, train_loader=None, 
local_round=None, sda_size=1, layer2_devices=None): + def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, sda_size=1, layer2_devices=None, model_name=None): self.data_count = 0 if self.layer_id == 1: - result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, layer2_devices=layer2_devices) + result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, layer2_devices=layer2_devices, model_name=model_name) else: - result = self.train_on_last_layer(model, lr, momentum, sda_size) + result = self.train_on_last_layer(model, lr, momentum, sda_size, model_name=model_name) return result, self.data_count \ No newline at end of file diff --git a/other/DCSL/src/Server.py b/other/DCSL/src/Server.py index 3c36c94..b7c14e7 100644 --- a/other/DCSL/src/Server.py +++ b/other/DCSL/src/Server.py @@ -227,7 +227,7 @@ def on_request(self, ch, method, props, body): if self.validation and self.round_result: state_dict_full = self.concatenate_and_avg_clusters() self.avg_state_dict = [] - if not src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger): + if not src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger, self.connection): self.logger.log_warning("Training failed!") else: torch.save(state_dict_full, f'{self.model_name}_{self.data_name}.pth') @@ -239,9 +239,6 @@ def on_request(self, ch, method, props, body): if self.round > 0: current_round = self.global_round - self.round + 1 - # Step decay: reduce LR by lr_decay every lr_step rounds - num_decays = current_round // self.lr_step - self.current_lr = self.lr * (self.lr_decay ** num_decays) self.logger.log_info(f"Start training round {current_round}") self.logger.log_info(f"Learning rate: {self.current_lr}") self.label_ = copy.deepcopy(self.label_counts) @@ -288,7 +285,30 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): keys = state_dict.keys() for key in keys: - state_dict[key] = full_state_dict[key] + if key in full_state_dict: + state_dict[key] = full_state_dict[key] + elif self.model_name == 'BERT': + flex_key = key + if key.startswith('layer1.'): + flex_key = key.replace('layer1.', 'embeddings.') + elif key.startswith('layer14.'): + flex_key = key.replace('layer14.', 'pooler.') + elif key.startswith('layer15.1.'): + flex_key = key.replace('layer15.1.', 'classifier.') + else: + import re + match = re.match(r'layer(\d+)\.(.*)', key) + if match: + layer_idx = int(match.group(1)) + if 2 <= layer_idx <= 13: + flex_key = f'layers.{layer_idx - 2}.{match.group(2)}' + + if flex_key in full_state_dict: + state_dict[key] = full_state_dict[flex_key] + else: + raise KeyError(f"{key} (mapped to {flex_key}) not found in weight file.") + else: + state_dict[key] = full_state_dict[key] self.logger.log_info("Model loaded successfully.") else: self.logger.log_info(f"File {filepath} does not exist.") diff --git a/other/DCSL/src/Validation.py b/other/DCSL/src/Validation.py index a5693f1..a4cc502 100644 --- a/other/DCSL/src/Validation.py +++ b/other/DCSL/src/Validation.py @@ -9,7 +9,7 @@ from src.model import * -def test(model_name, data_name, state_dict_full, logger): +def test(model_name, data_name, state_dict_full, logger, server_connection=None): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if data_name == "MNIST": transform_test = transforms.Compose([ @@ -26,15 +26,17 @@ def test(model_name, data_name, state_dict_full, logger): testset = torchvision.datasets.CIFAR10(root='./data', 
train=False, download=True, transform=transform_test) elif data_name == "SPEECHCOMMANDS": from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset - testset = SpeechCommandsDataset(root='./data', subset='testing') - elif data_name == "EMOTION": + testset_full = SpeechCommandsDataset(root='./data', subset='testing') + indices = np.random.choice(len(testset_full), 5000, replace=False) + testset = torch.utils.data.Subset(testset_full, indices) + elif data_name == "AGNEWS": from datasets import load_dataset from transformers import BertTokenizer - from src.dataset.EMOTION import EMOTIONDataset, load_test_EMOTION + from src.dataset.AGNEWS import AGNEWS_DATASET, load_test_AGNEWS dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - test_texts, test_labels = load_test_EMOTION(1000, dataset) - testset = EMOTIONDataset(test_texts, test_labels, tokenizer, max_length=128) + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + test_texts, test_labels = load_test_AGNEWS(1000, dataset) + testset = AGNEWS_DATASET(test_texts, test_labels, tokenizer, max_length=128) else: raise ValueError(f"Data name '{data_name}' is not valid.") @@ -58,15 +60,28 @@ def test(model_name, data_name, state_dict_full, logger): total = 0 with torch.no_grad(): - for data, target in tqdm(test_loader): - data = data.to(device) - target = target.to(device) - output = model(data) + for i, batch in enumerate(tqdm(test_loader)): + if isinstance(batch, dict) and 'input_ids' in batch: + data = batch['input_ids'].to(device) + target = batch['labels'].to(device) + output = model(data, attention_mask=batch['attention_mask'].to(device)) + else: + data, target = batch + data = data.to(device) + target = target.to(device) + output = model(data) + loss = criterion(output, target) test_loss += loss.item() * target.size(0) pred = output.argmax(dim=1) correct += pred.eq(target).sum().item() total += target.size(0) + + if server_connection is not None and i % 5 == 0: + try: + server_connection.process_data_events(time_limit=0) + except Exception: + pass test_loss /= total accuracy = 100.0 * correct / total diff --git a/other/DCSL/src/dataset/EMOTION.py b/other/DCSL/src/dataset/AGNEWS.py similarity index 51% rename from other/DCSL/src/dataset/EMOTION.py rename to other/DCSL/src/dataset/AGNEWS.py index b24b487..2f73a74 100644 --- a/other/DCSL/src/dataset/EMOTION.py +++ b/other/DCSL/src/dataset/AGNEWS.py @@ -1,8 +1,6 @@ import torch -import random -from collections import defaultdict -class EMOTIONDataset(torch.utils.data.Dataset): +class AGNEWS_DATASET(torch.utils.data.Dataset): def __init__(self, texts, labels, tokenizer, max_length=128): self.texts = texts self.labels = labels @@ -25,40 +23,23 @@ def __getitem__(self, idx): return_tensors='pt' ) - return encoding['input_ids'].flatten(), torch.tensor(label, dtype=torch.long) - -def load_train_EMOTION(dataset=None, distribution=None): - random.seed(1) - - train_data = dataset['train'] - - train_target_counts = {k: v for k, v in enumerate(distribution)} - - train_by_class = defaultdict(list) - for text, label in zip(train_data['text'], train_data['label']): - train_by_class[label].append((text, label)) + return { + 'input_ids': encoding['input_ids'].flatten(), + 'attention_mask': encoding['attention_mask'].flatten(), + 'labels': torch.tensor(label, dtype=torch.long) + } - train_texts, train_labels = [], [] - for label, count in train_target_counts.items(): - samples = 
random.sample(train_by_class[label], count) - train_texts.extend([t for t, _ in samples]) - train_labels.extend([l for _, l in samples]) - - print("Train samples:", len(train_texts), {l: train_labels.count(l) for l in set(train_labels)}) - - return train_texts, train_labels +from collections import defaultdict +import random -def load_test_EMOTION(test_total=1000, dataset=None): - random.seed(1) +def load_test_AGNEWS(num_samples, dataset): test_data = dataset['test'] - class_distribution = { - 0: 0.25, - 1: 0.25, - 2: 0.25, - 3: 0.25 - } - test_target_counts = {k: int(v * test_total) for k, v in class_distribution.items()} + + # AGNEWS có 4 class, chia đều theo num_samples + distribution = [num_samples // 4] * 4 + test_target_counts = {k: v for k, v in enumerate(distribution)} test_by_class = defaultdict(list) + for text, label in zip(test_data['text'], test_data['label']): test_by_class[label].append((text, label)) @@ -69,5 +50,4 @@ def load_test_EMOTION(test_total=1000, dataset=None): test_labels.extend([l for _, l in samples]) print("Test samples:", len(test_texts), {l: test_labels.count(l) for l in set(test_labels)}) - return test_texts, test_labels \ No newline at end of file diff --git a/other/DCSL/src/dataset/SPEECHCOMMANDS.py b/other/DCSL/src/dataset/SPEECHCOMMANDS.py index b50f27e..41296f0 100644 --- a/other/DCSL/src/dataset/SPEECHCOMMANDS.py +++ b/other/DCSL/src/dataset/SPEECHCOMMANDS.py @@ -6,7 +6,6 @@ from scipy.io import wavfile from scipy.fftpack import dct -# 12 classes standard (10 keywords + silence + unknown) CLASSES = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 'silence', 'unknown'] def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=160, n_mels=40): @@ -87,13 +86,6 @@ class SpeechCommandsDataset(Dataset): """Google Speech Commands V2 dataset with MFCC features for KWT model""" def __init__(self, root='./data', subset='training', n_silence=2300): - """ - Args: - root: Data directory - subset: 'training', 'validation', or 'testing' - n_silence: Number of silence samples to generate (training). - For val/test, uses n_silence // 9 (~260). 
- """ self.root = os.path.join(root, 'SpeechCommands', 'speech_commands_v0.02') self.subset = subset self.samples = [] # list of (path_or_None, label_str) diff --git a/other/DCSL/src/model/BERT_AGNEWS.py b/other/DCSL/src/model/BERT_AGNEWS.py new file mode 100644 index 0000000..e5f2fcf --- /dev/null +++ b/other/DCSL/src/model/BERT_AGNEWS.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class DotDict(dict): + def __getattr__(self, k): + try: return self[k] + except KeyError: raise AttributeError(k) + def __setattr__(self, k, v): self[k] = v + def __delattr__(self, k): del self[k] + +# AG_NEWS have 4 layers: World, Sports, Business, Sci/Tech +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + +class BertSdpaSelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, dropout_prob): + super(BertSdpaSelfAttention, self).__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + self.dropout = nn.Dropout(dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + import math + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + 
return context_layer + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob) + self.output = BertSelfOutput(hidden_size, dropout_prob) + + def forward(self, hidden_states): + self_output = self.self(hidden_states) + attention_output = self.output(self_output, hidden_states) + return attention_output + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = nn.GELU() + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_prob): + super(BertLayer, self).__init__() + self.attention = BertAttention(hidden_size, num_attention_heads, dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size) + self.output = BertOutput(hidden_size, intermediate_size, dropout_prob) + + def forward(self, hidden_states): + attention_output = self.attention(hidden_states) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class BertClassifier(nn.Module): + def __init__(self, hidden_size, num_labels, dropout_prob=0.1): + super(BertClassifier, self).__init__() + self.dropout = nn.Dropout(dropout_prob) + self.classifier = nn.Linear(hidden_size, num_labels) + + def forward(self, pooled_output): + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + +class BERT_AGNEWS(nn.Module): + def __init__(self, vocab_size=28996, hidden_size=768, num_attention_heads=12, intermediate_size=3072, + max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, n_block=12, + start_layer=0, end_layer=15): + super(BERT_AGNEWS, 
self).__init__() + self.start_layer = start_layer + self.end_layer = 15 if end_layer == -1 else end_layer + self.config = DotDict( + model_type="bert", + vocab_size=vocab_size, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + bos_token_id=101, eos_token_id=102, pad_token_id=0, + is_encoder_decoder=False, tie_word_embeddings=False, + use_return_dict=True, output_attentions=False, output_hidden_states=False + ) + + if self.start_layer < 1 <= self.end_layer: + self.layer1 = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) + + for i in range(12): + layer_idx = 2 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob)) + + if self.start_layer < 14 <= self.end_layer: + self.layer14 = BertPooler(hidden_size) + + if self.start_layer < 15 <= self.end_layer: + self.layer15 = nn.Sequential( + nn.Dropout(dropout_prob), + nn.Linear(hidden_size, 4) + ) + + def forward(self, input_ids=None, token_type_ids=None, **kwargs): + x = input_ids + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x, token_type_ids) + + for i in range(12): + layer_idx = 2 + i + if self.start_layer < layer_idx <= self.end_layer: + layer_module = getattr(self, f'layer{layer_idx}') + x = layer_module(x) + + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + + return x \ No newline at end of file diff --git a/other/DCSL/src/model/BERT_EMOTION.py b/other/DCSL/src/model/BERT_EMOTION.py deleted file mode 100644 index efb4ab9..0000000 --- a/other/DCSL/src/model/BERT_EMOTION.py +++ /dev/null @@ -1,428 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -# Configuration constants -NUM_SAMPLES = 5000 -num_labels = 6 -vocab_size = 30522 -hidden_size = 768 -num_hidden_layers = 12 -num_attention_heads = 12 -intermediate_size = 3072 -max_position_embeddings = 512 -type_vocab_size = 2 -dropout_prob = 0.1 - -# BertEmbeddings class -class BertEmbeddings(nn.Module): - def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - -# BertSdpaSelfAttention class -class BertSdpaSelfAttention(nn.Module): - 
def __init__(self, hidden_size, num_attention_heads, dropout_prob): - super(BertSdpaSelfAttention, self).__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) - self.dropout = nn.Dropout(dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask=None): - batch_size, seq_length, hidden_size = hidden_states.size() - - # Create query, key, value projections - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - # Reshape for multi-head attention - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - # Perform attention score calculation - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - # Scale attention scores - import math - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Apply attention mask if provided - if attention_mask is not None: - # Reshape attention_mask from [batch_size, seq_length] to [batch_size, 1, 1, seq_length] - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - # Convert 1s (valid tokens) to 0s and 0s (padding) to large negative values - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - attention_scores = attention_scores + extended_attention_mask - - # Apply softmax to get probabilities - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - # Calculate context by attending to values - context_layer = torch.matmul(attention_probs, value_layer) - - # Reshape back to [batch_size, seq_length, hidden_size] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - -# BertSelfOutput class -class BertSelfOutput(nn.Module): - def __init__(self, hidden_size, dropout_prob): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - -# BertAttention class -class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, dropout_prob): - super(BertAttention, self).__init__() - self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob) - self.output = BertSelfOutput(hidden_size, dropout_prob) - - def forward(self, hidden_states, attention_mask=None): - self_output = self.self(hidden_states, attention_mask) - attention_output = self.output(self_output, hidden_states) - return attention_output - -# BertIntermediate class -class 
BertIntermediate(nn.Module): - def __init__(self, hidden_size, intermediate_size): - super(BertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = nn.GELU() - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - -# BertOutput class -class BertOutput(nn.Module): - def __init__(self, hidden_size, intermediate_size, dropout_prob): - super(BertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -# BertPooler class -class BertPooler(nn.Module): - def __init__(self, hidden_size): - super(BertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - -# BertClassifier class -class BertClassifier(nn.Module): - def __init__(self, hidden_size, num_labels, dropout_prob=0.1): - super(BertClassifier, self).__init__() - self.dropout = nn.Dropout(dropout_prob) - self.classifier = nn.Linear(hidden_size, num_labels) - - def forward(self, pooled_output): - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - return logits - -# Complete BERT Model -class BERT_EMOTION(nn.Module): - def __init__(self,start_layer= 0, end_layer= 27, vocab_size=30522, hidden_size=768, intermediate_size=3072, - num_attention_heads=12, num_labels=4, max_position_embeddings=512, - type_vocab_size=2, dropout_prob=0.1, num_hidden_layers=12): - - super(BERT_EMOTION, self).__init__() - - self.start_layer = start_layer - self.end_layer = end_layer - - if (self.start_layer < 1) and (self.end_layer >= 1): - self.layer1 = BertEmbeddings(vocab_size, hidden_size , max_position_embeddings, type_vocab_size, dropout_prob) - - if (self.start_layer < 2) and (self.end_layer >= 2): - self.layer2 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 3) and (self.end_layer >= 3): - self.layer3 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 4) and (self.end_layer >= 4): - self.layer4 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 5) and (self.end_layer >= 5): - self.layer5 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 6) and (self.end_layer >= 6): - self.layer6 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 7) and (self.end_layer >= 7): - self.layer7 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) 
- ]) - - if (self.start_layer < 8) and (self.end_layer >= 8): - self.layer8 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 9) and (self.end_layer >= 9): - self.layer9 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 10) and (self.end_layer >= 10): - self.layer10 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 11) and (self.end_layer >= 11): - self.layer11 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 12) and (self.end_layer >= 12): - self.layer12 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 13) and (self.end_layer >= 13): - self.layer13 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 14) and (self.end_layer >= 14): - self.layer14 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 15) and (self.end_layer >= 15): - self.layer15 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 16) and (self.end_layer >= 16): - self.layer16 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 17) and (self.end_layer >= 17): - self.layer17 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 18) and (self.end_layer >= 18): - self.layer18 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 19) and (self.end_layer >= 19): - self.layer19 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 20) and (self.end_layer >= 20): - self.layer20 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 21) and (self.end_layer >= 21): - self.layer21 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 22) and (self.end_layer >= 22): - self.layer22 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 23) and (self.end_layer >= 23): - self.layer23 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 24) and (self.end_layer >= 24): - self.layer24 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - 
BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 25) and (self.end_layer >= 25): - self.layer25 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 26) and (self.end_layer >= 26): - self.layer26 = BertPooler(hidden_size) - - if (self.start_layer < 27) and (self.end_layer >= 27): - self.layer27 = BertClassifier(hidden_size, num_labels, dropout_prob) - - def forward(self, x, attention_mask=None, token_type_ids=None): - if (self.start_layer < 1) and (self.end_layer >= 1): - x = self.layer1(x, token_type_ids) - - if (self.start_layer < 2) and (self.end_layer >= 2): - x = self.layer2[1](self.layer2[0](x, attention_mask), x) - - if (self.start_layer < 3) and (self.end_layer >= 3): - x = self.layer3[1](self.layer3[0](x), x) - - if (self.start_layer < 4) and (self.end_layer >= 4): - x = self.layer4[1](self.layer4[0](x, attention_mask), x) - - if (self.start_layer < 5) and (self.end_layer >= 5): - x = self.layer5[1](self.layer5[0](x), x) - - if (self.start_layer < 6) and (self.end_layer >= 6): - x = self.layer6[1](self.layer6[0](x, attention_mask), x) - - if (self.start_layer < 7) and (self.end_layer >= 7): - x = self.layer7[1](self.layer7[0](x), x) - - if (self.start_layer < 8) and (self.end_layer >= 8): - x = self.layer8[1](self.layer8[0](x, attention_mask), x) - - if (self.start_layer < 9) and (self.end_layer >= 9): - x = self.layer9[1](self.layer9[0](x), x) - - if (self.start_layer < 10) and (self.end_layer >= 10): - x = self.layer10[1](self.layer10[0](x, attention_mask), x) - - if (self.start_layer < 11) and (self.end_layer >= 11): - x = self.layer11[1](self.layer11[0](x), x) - - if (self.start_layer < 12) and (self.end_layer >= 12): - x = self.layer12[1](self.layer12[0](x, attention_mask), x) - - if (self.start_layer < 13) and (self.end_layer >= 13): - x = self.layer13[1](self.layer13[0](x), x) - - if (self.start_layer < 14) and (self.end_layer >= 14): - x = self.layer14[1](self.layer14[0](x, attention_mask), x) - - if (self.start_layer < 15) and (self.end_layer >= 15): - x = self.layer15[1](self.layer15[0](x), x) - - if (self.start_layer < 16) and (self.end_layer >= 16): - x = self.layer16[1](self.layer16[0](x, attention_mask), x) - - if (self.start_layer < 17) and (self.end_layer >= 17): - x = self.layer17[1](self.layer17[0](x), x) - - if (self.start_layer < 18) and (self.end_layer >= 18): - x = self.layer18[1](self.layer18[0](x, attention_mask), x) - - if (self.start_layer < 19) and (self.end_layer >= 19): - x = self.layer19[1](self.layer19[0](x), x) - - if (self.start_layer < 20) and (self.end_layer >= 20): - x = self.layer20[1](self.layer20[0](x, attention_mask), x) - - if (self.start_layer < 21) and (self.end_layer >= 21): - x = self.layer21[1](self.layer21[0](x), x) - - if (self.start_layer < 22) and (self.end_layer >= 22): - x = self.layer22[1](self.layer22[0](x, attention_mask), x) - - if (self.start_layer < 23) and (self.end_layer >= 23): - x = self.layer23[1](self.layer23[0](x), x) - - if (self.start_layer < 24) and (self.end_layer >= 24): - x = self.layer24[1](self.layer24[0](x, attention_mask), x) - - if (self.start_layer < 25) and (self.end_layer >= 25): - x = self.layer25[1](self.layer25[0](x), x) - - if (self.start_layer < 26) and (self.end_layer >= 26): - x = self.layer26(x) - - if (self.start_layer < 27) and (self.end_layer >= 27): - x = self.layer27(x) - - return x diff --git a/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py 
b/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py index 4a53f33..4f716b1 100644 --- a/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py +++ b/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py @@ -27,9 +27,7 @@ def forward(self, x): class KWT_SPEECHCOMMANDS(nn.Module): """ - Keyword Transformer (KWT-1) for Speech Commands - Split Learning version KWT-1: dim=64, mlp_dim=256, heads=1, layers=12 - Layers: 1: Linear embedding (n_mfcc -> embed_dim) 2: CLS token concatenation @@ -42,12 +40,10 @@ class KWT_SPEECHCOMMANDS(nn.Module): def __init__(self, start_layer=0, end_layer=17): super().__init__() self.start_layer = start_layer - # Handle -1 as "all remaining layers" self.end_layer = 17 if end_layer == -1 else end_layer - # KWT-1 config (matches Colab) n_mfcc = 40 - time_steps = 98 # (16000 - 480) / 160 + 1 ≈ 98 + time_steps = 98 embed_dim = 64 num_heads = 1 mlp_dim = 256 @@ -91,17 +87,16 @@ def _init_weights(self): nn.init.trunc_normal_(self.pos_embed, std=0.02) def forward(self, x): - # Input x shape: (batch, n_mfcc, time_steps) # Layer 1: Linear embedding if self.start_layer < 1 <= self.end_layer: - x = x.transpose(1, 2) # (batch, time_steps, n_mfcc) - x = self.layer1(x) # (batch, time_steps, embed_dim) + x = x.transpose(1, 2) + x = self.layer1(x) # Layer 2: CLS token if self.start_layer < 2 <= self.end_layer: cls_token = self.cls_token.expand(x.size(0), -1, -1) - x = torch.cat([cls_token, x], dim=1) # (batch, time_steps+1, embed_dim) + x = torch.cat([cls_token, x], dim=1) # Layer 3: Positional embedding + Dropout if self.start_layer < 3 <= self.end_layer: diff --git a/other/DCSL/src/model/__init__.py b/other/DCSL/src/model/__init__.py index 8353b7e..b1922ce 100644 --- a/other/DCSL/src/model/__init__.py +++ b/other/DCSL/src/model/__init__.py @@ -4,5 +4,5 @@ from .VGG16_MNIST import * from .ViT_CIFAR10 import * from .ViT_MNIST import * -from .BERT_EMOTION import * +from .BERT_AGNEWS import * from .KWT_SPEECHCOMMANDS import * From fd0852d43cd6aca22f83bb602c01dc654f881a16 Mon Sep 17 00:00:00 2001 From: truongtruong373 Date: Thu, 26 Mar 2026 10:11:54 +0700 Subject: [PATCH 4/4] fix : fine-tuning Bert with AGNEWS --- other/DCSL/src/Log.py | 6 - other/DCSL/src/RpcClient.py | 81 +------------- other/DCSL/src/Server.py | 24 +--- other/DCSL/src/Utils.py | 12 -- other/DCSL/src/Validation.py | 36 +----- other/DCSL/src/dataset/AGNEWS.py | 25 +---- other/DCSL/src/dataset/SPEECHCOMMANDS.py | 101 +++++++---------- other/DCSL/src/dataset/dataloader.py | 135 +++++++++++++++++++++++ other/DCSL/src/model/BERT_AGNEWS.py | 5 +- 9 files changed, 187 insertions(+), 238 deletions(-) create mode 100644 other/DCSL/src/dataset/dataloader.py diff --git a/other/DCSL/src/Log.py b/other/DCSL/src/Log.py index 640ad13..f8f7eac 100644 --- a/other/DCSL/src/Log.py +++ b/other/DCSL/src/Log.py @@ -1,6 +1,5 @@ import logging - class Colors: COLORS = { "header": '\033[95m', @@ -11,23 +10,18 @@ class Colors: "end": '\033[0m' } - class Logger: def __init__(self, log_path, debug_mode=False): - # Thiết lập logger với tên "my_logger" self.logger = logging.getLogger("my_logger") self.logger.setLevel(logging.DEBUG) # Mức log self.debug_mode = debug_mode - # Tạo file handler để ghi log vào file file_handler = logging.FileHandler(log_path) file_handler.setLevel(logging.DEBUG) - # Định dạng log formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') file_handler.setFormatter(formatter) - # Gắn file handler vào logger self.logger.addHandler(file_handler) def log_info(self, message): diff --git a/other/DCSL/src/RpcClient.py 
b/other/DCSL/src/RpcClient.py index 37b4a5d..51f2b98 100644 --- a/other/DCSL/src/RpcClient.py +++ b/other/DCSL/src/RpcClient.py @@ -1,15 +1,11 @@ import time import pickle -import random import copy -import torchvision -import torchvision.transforms as transforms - -from collections import defaultdict -from tqdm import tqdm import src.Log + from src.model import * +from src.dataset.dataloader import data_loader from peft import LoraConfig, get_peft_model @@ -24,6 +20,7 @@ def __init__(self, client_id, layer_id, channel, train_func, device): self.response = None self.model = None + self.train_loader = None self.label_count = None self.peft_config = None @@ -59,62 +56,6 @@ def response_message(self, body): if self.label_count is not None: src.Log.print_with_color(f"Label distribution of client: {self.label_count}", "yellow") - # Load training dataset - if self.layer_id == 1 and data_name and not self.train_set and not self.label_to_indices: - if data_name == "MNIST": - transform_train = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) - self.train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, - transform=transform_train) - - elif data_name == "CIFAR10": - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, - transform=transform_train) - elif data_name == "SPEECHCOMMANDS": - from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset - self.train_set = SpeechCommandsDataset(root='./data', subset='training') - elif data_name == "AGNEWS": - from datasets import load_dataset - from transformers import BertTokenizer - from src.dataset.AGNEWS import AGNEWS_DATASET - - dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') - tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - - train_data = dataset['train'] - texts = train_data['text'] - labels = train_data['label'] - - self.train_set = AGNEWS_DATASET(texts, labels, tokenizer, max_length=128) - else: - self.train_set = None - raise ValueError(f"Data name '{data_name}' is not valid.") - - self.label_to_indices = defaultdict(list) - if data_name == "AGNEWS": - for idx, label in enumerate(self.train_set.labels): - self.label_to_indices[int(label)].append(idx) - elif data_name == "SPEECHCOMMANDS": - from src.dataset.SPEECHCOMMANDS import CLASSES - for idx, (audio_path, label_name) in enumerate(self.train_set.samples): - if label_name in CLASSES: - label_idx = CLASSES.index(label_name) - else: - label_idx = CLASSES.index('unknown') - self.label_to_indices[label_idx].append(idx) - else: - for idx, (_, label) in tqdm(enumerate(self.train_set)): - self.label_to_indices[int(label)].append(idx) - - # Load model if self.model is None: klass = globals()[f'{model_name}_{data_name}'] @@ -132,11 +73,9 @@ def response_message(self, body): sda_size = self.response.get("sda_size", 1) layer2_devices = self.response.get("layer2_devices", []) - # Read parameters and load to model if state_dict: self.model.load_state_dict(state_dict) - # Apply LoRA for BERT model if model_name == 'BERT': if self.peft_config is None: self.peft_config = LoraConfig( @@ -152,25 +91,18 @@ def response_message(self, body): self.model.to(self.device) - # Start training if self.layer_id == 1: - selected_indices = [] 
- for label, count in enumerate(self.label_count): - selected_indices.extend(random.sample(self.label_to_indices[label], count)) - - subset = torch.utils.data.Subset(self.train_set, selected_indices) - train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + if self.train_loader is None: + self.train_loader = data_loader(data_name, batch_size, self.label_count, train=True) - result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, layer2_devices=layer2_devices, model_name=model_name) + result, size = self.train_func(self.model, lr, momentum, self.train_loader, local_round=local_round, layer2_devices=layer2_devices, model_name=model_name) else: result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size, model_name=model_name) - # Merge LoRA weights back for BERT if model_name == 'BERT': self.model = self.model.merge_and_unload() - # Stop training, then send parameters to server model_state_dict = copy.deepcopy(self.model.state_dict()) if self.device != "cpu": for key in model_state_dict: @@ -185,7 +117,6 @@ def response_message(self, body): elif action == "STOP": return False - def send_to_server(self, message): self.response = None diff --git a/other/DCSL/src/Server.py b/other/DCSL/src/Server.py index b7c14e7..94f9f29 100644 --- a/other/DCSL/src/Server.py +++ b/other/DCSL/src/Server.py @@ -76,6 +76,9 @@ def __init__(self, config): self.idx = 0 self.current_clients_cluster = 0 + self.current_lr = 0 + self.sda_size = 0 + self.channel.basic_qos(prefetch_count=1) self.reply_channel = self.connection.channel() self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request) @@ -229,6 +232,7 @@ def on_request(self, ch, method, props, body): self.avg_state_dict = [] if not src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger, self.connection): self.logger.log_warning("Training failed!") + self.round = 0 else: torch.save(state_dict_full, f'{self.model_name}_{self.data_name}.pth') self.round -= 1 @@ -287,26 +291,6 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): for key in keys: if key in full_state_dict: state_dict[key] = full_state_dict[key] - elif self.model_name == 'BERT': - flex_key = key - if key.startswith('layer1.'): - flex_key = key.replace('layer1.', 'embeddings.') - elif key.startswith('layer14.'): - flex_key = key.replace('layer14.', 'pooler.') - elif key.startswith('layer15.1.'): - flex_key = key.replace('layer15.1.', 'classifier.') - else: - import re - match = re.match(r'layer(\d+)\.(.*)', key) - if match: - layer_idx = int(match.group(1)) - if 2 <= layer_idx <= 13: - flex_key = f'layers.{layer_idx - 2}.{match.group(2)}' - - if flex_key in full_state_dict: - state_dict[key] = full_state_dict[flex_key] - else: - raise KeyError(f"{key} (mapped to {flex_key}) not found in weight file.") else: state_dict[key] = full_state_dict[key] self.logger.log_info("Model loaded successfully.") diff --git a/other/DCSL/src/Utils.py b/other/DCSL/src/Utils.py index dc42a72..2c7c815 100644 --- a/other/DCSL/src/Utils.py +++ b/other/DCSL/src/Utils.py @@ -1,5 +1,3 @@ -import numpy as np -import random import pika import torch @@ -34,12 +32,6 @@ def delete_old_queues(address, username, password, virtual_host): return False def fedavg_state_dicts(state_dicts, weights = None): - """ - Trung bình (FedAvg) một list các state_dict. 
- - state_dicts: list các dict {param_name: tensor} - - weights: list trọng số tương ứng (mặc định None nghĩa là mỗi model weight=1) - Trả về một dict {param_name: tensor_avg} - """ num = len(state_dicts) if num == 0: raise ValueError("fedavg_state_dicts: không có state_dict nào để trung bình.") @@ -48,12 +40,10 @@ def fedavg_state_dicts(state_dicts, weights = None): weights = [1.0] * num total_w = sum(weights) - # Tập hợp tất cả key all_keys = set().union(*(sd.keys() for sd in state_dicts)) avg_dict = {} for key in all_keys: - # gom tensor + weight, xử lý NaN acc = None for sd, w in zip(state_dicts, weights): if key not in sd: @@ -64,10 +54,8 @@ def fedavg_state_dicts(state_dicts, weights = None): t = t * w acc = t if acc is None else acc + t - # chia trung bình avg = acc / total_w - # cast về dtype gốc orig = next(sd[key] for sd in state_dicts if key in sd) if orig.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.bool): avg = avg.round().to(orig.dtype) diff --git a/other/DCSL/src/Validation.py b/other/DCSL/src/Validation.py index a4cc502..ebc9a25 100644 --- a/other/DCSL/src/Validation.py +++ b/other/DCSL/src/Validation.py @@ -1,46 +1,14 @@ - import numpy as np import math from tqdm import tqdm -import torchvision -import torchvision.transforms as transforms -import torch.nn as nn - +from src.dataset.dataloader import data_loader from src.model import * def test(model_name, data_name, state_dict_full, logger, server_connection=None): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if data_name == "MNIST": - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) - testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test) - - elif data_name == "CIFAR10": - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) - elif data_name == "SPEECHCOMMANDS": - from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset - testset_full = SpeechCommandsDataset(root='./data', subset='testing') - indices = np.random.choice(len(testset_full), 5000, replace=False) - testset = torch.utils.data.Subset(testset_full, indices) - elif data_name == "AGNEWS": - from datasets import load_dataset - from transformers import BertTokenizer - from src.dataset.AGNEWS import AGNEWS_DATASET, load_test_AGNEWS - dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') - tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - test_texts, test_labels = load_test_AGNEWS(1000, dataset) - testset = AGNEWS_DATASET(test_texts, test_labels, tokenizer, max_length=128) - else: - raise ValueError(f"Data name '{data_name}' is not valid.") - test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) + test_loader = data_loader(data_name=data_name, train=False) criterion = nn.CrossEntropyLoss() diff --git a/other/DCSL/src/dataset/AGNEWS.py b/other/DCSL/src/dataset/AGNEWS.py index 2f73a74..4fc37dc 100644 --- a/other/DCSL/src/dataset/AGNEWS.py +++ b/other/DCSL/src/dataset/AGNEWS.py @@ -27,27 +27,4 @@ def __getitem__(self, idx): 'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'labels': torch.tensor(label, dtype=torch.long) - } - -from collections import defaultdict 
-import random - -def load_test_AGNEWS(num_samples, dataset): - test_data = dataset['test'] - - # AGNEWS có 4 class, chia đều theo num_samples - distribution = [num_samples // 4] * 4 - test_target_counts = {k: v for k, v in enumerate(distribution)} - test_by_class = defaultdict(list) - - for text, label in zip(test_data['text'], test_data['label']): - test_by_class[label].append((text, label)) - - test_texts, test_labels = [], [] - for label, count in test_target_counts.items(): - samples = random.sample(test_by_class[label], count) - test_texts.extend([t for t, _ in samples]) - test_labels.extend([l for _, l in samples]) - - print("Test samples:", len(test_texts), {l: test_labels.count(l) for l in set(test_labels)}) - return test_texts, test_labels \ No newline at end of file + } \ No newline at end of file diff --git a/other/DCSL/src/dataset/SPEECHCOMMANDS.py b/other/DCSL/src/dataset/SPEECHCOMMANDS.py index 41296f0..4caf008 100644 --- a/other/DCSL/src/dataset/SPEECHCOMMANDS.py +++ b/other/DCSL/src/dataset/SPEECHCOMMANDS.py @@ -10,30 +10,25 @@ def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=160, n_mels=40): """Compute MFCC features using scipy (no torchaudio backend needed)""" - # Pre-emphasis emphasized = np.append(waveform[0], waveform[1:] - 0.97 * waveform[:-1]) - - # Framing + frame_length = n_fft num_frames = 1 + (len(emphasized) - frame_length) // hop_length frames = np.zeros((num_frames, frame_length)) for i in range(num_frames): - frames[i] = emphasized[i * hop_length : i * hop_length + frame_length] - - # Windowing + frames[i] = emphasized[i * hop_length: i * hop_length + frame_length] + frames *= np.hamming(frame_length) - - # FFT + mag_frames = np.absolute(np.fft.rfft(frames, n_fft)) pow_frames = (1.0 / n_fft) * (mag_frames ** 2) - - # Mel filterbank + low_freq_mel = 0 high_freq_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700) mel_points = np.linspace(low_freq_mel, high_freq_mel, n_mels + 2) hz_points = 700 * (10 ** (mel_points / 2595) - 1) bin_points = np.floor((n_fft + 1) * hz_points / sample_rate).astype(int) - + fbank = np.zeros((n_mels, int(np.floor(n_fft / 2 + 1)))) for m in range(1, n_mels + 1): f_m_minus = bin_points[m - 1] @@ -43,15 +38,14 @@ def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=1 fbank[m - 1, k] = (k - f_m_minus) / (f_m - f_m_minus) for k in range(f_m, f_m_plus): fbank[m - 1, k] = (f_m_plus - k) / (f_m_plus - f_m) - + filter_banks = np.dot(pow_frames, fbank.T) filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks) filter_banks = 20 * np.log10(filter_banks) - - # DCT to get MFCCs + mfcc = dct(filter_banks, type=2, axis=1, norm='ortho')[:, :n_mfcc] - - return mfcc.T # (n_mfcc, time_steps) + + return mfcc.T def _load_background_noise(root): @@ -61,7 +55,7 @@ def _load_background_noise(root): if not os.path.exists(noise_dir): print(f"[WARNING] Background noise dir not found: {noise_dir}") return noise_data - + for f in os.listdir(noise_dir): if not f.endswith('.wav'): continue @@ -77,125 +71,106 @@ def _load_background_noise(root): noise_data.append(wav) except Exception as e: print(f"[WARNING] Error reading noise file {fpath}: {e}") - + print(f"Loaded {len(noise_data)} background noise files") return noise_data - class SpeechCommandsDataset(Dataset): - """Google Speech Commands V2 dataset with MFCC features for KWT model""" - def __init__(self, root='./data', subset='training', n_silence=2300): self.root = os.path.join(root, 'SpeechCommands', 'speech_commands_v0.02') 
self.subset = subset - self.samples = [] # list of (path_or_None, label_str) - self.noise_data = [] # background noise waveforms for silence - - # Check if dataset exists + self.samples = [] + self.noise_data = [] + if not os.path.exists(self.root): raise RuntimeError(f"Dataset not found at {self.root}. Please download manually.") - - # Load background noise for silence class + self.noise_data = _load_background_noise(self.root) - - # Load validation/testing list + val_list = set() test_list = set() - + val_file = os.path.join(self.root, 'validation_list.txt') test_file = os.path.join(self.root, 'testing_list.txt') - + if os.path.exists(val_file): with open(val_file) as f: val_list = set(line.strip() for line in f) if os.path.exists(test_file): with open(test_file) as f: test_list = set(line.strip() for line in f) - - # Collect keyword & unknown samples based on subset + for label_dir in os.listdir(self.root): label_path = os.path.join(self.root, label_dir) if not os.path.isdir(label_path) or label_dir.startswith('_'): continue - + for audio_file in os.listdir(label_path): if not audio_file.endswith('.wav'): continue - + rel_path = os.path.join(label_dir, audio_file) - + if subset == 'validation' and rel_path in val_list: self.samples.append((os.path.join(label_path, audio_file), label_dir)) elif subset == 'testing' and rel_path in test_list: self.samples.append((os.path.join(label_path, audio_file), label_dir)) elif subset == 'training' and rel_path not in val_list and rel_path not in test_list: self.samples.append((os.path.join(label_path, audio_file), label_dir)) - - # Generate silence samples (random 1s crops from background noise) - # Following the paper: silence class created from _background_noise_ + if self.noise_data: if subset == 'training': num_silence = n_silence else: - num_silence = max(1, n_silence // 9) # ~260 for val/test - + num_silence = max(1, n_silence // 9) + for _ in range(num_silence): - # path=None marks this as a silence sample (generated on-the-fly) self.samples.append((None, 'silence')) - + print(f"Added {num_silence} silence samples from background noise") - + print(f"Total {len(self.samples)} samples for {subset}") - + def _get_silence_waveform(self): - """Generate a 1-second silence waveform by random crop from background noise""" target_length = 16000 noise = random.choice(self.noise_data) - + if len(noise) <= target_length: - # Noise file too short, pad with zeros waveform = np.pad(noise, (0, max(0, target_length - len(noise)))) else: start = random.randint(0, len(noise) - target_length) - waveform = noise[start : start + target_length] - + waveform = noise[start: start + target_length] + return waveform - + def __getitem__(self, idx): audio_path, label = self.samples[idx] - + try: if audio_path is None: - # Silence sample: random crop from background noise waveform = self._get_silence_waveform() else: - # Load audio using scipy sample_rate, waveform = wavfile.read(audio_path) - - # Convert to float and normalize if waveform.dtype == np.int16: waveform = waveform.astype(np.float32) / 32768.0 elif waveform.dtype == np.int32: waveform = waveform.astype(np.float32) / 2147483648.0 else: waveform = waveform.astype(np.float32) - - # Pad/trim to 1 second (16000 samples) + target_length = 16000 if len(waveform) < target_length: waveform = np.pad(waveform, (0, target_length - len(waveform))) else: waveform = waveform[:target_length] - - # Compute MFCC features + mfcc = compute_mfcc(waveform, sample_rate=16000, n_mfcc=40) mfcc = torch.tensor(mfcc, 
dtype=torch.float32) - + except Exception as e: print(f"[WARNING] Error processing sample {idx}: {e}, using zeros") mfcc = torch.zeros(40, 98, dtype=torch.float32) - - # Map label to class index + if label in CLASSES: label_idx = CLASSES.index(label) else: diff --git a/other/DCSL/src/dataset/dataloader.py b/other/DCSL/src/dataset/dataloader.py new file mode 100644 index 0000000..5594bdc --- /dev/null +++ b/other/DCSL/src/dataset/dataloader.py @@ -0,0 +1,135 @@ +import random + +import torch +import torchvision +from collections import defaultdict +from tqdm import tqdm + +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from transformers import BertTokenizer +from datasets import load_dataset + +from src.dataset.AGNEWS import AGNEWS_DATASET +from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + +def AGNEWS(batch_size=None, distribution=None, train=True): + dataset = load_dataset( + 'ag_news', + download_mode='reuse_dataset_if_exists', + cache_dir='./hf_cache' + ) + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + if train: + train_data = dataset['train'] + train_target_counts = {k: v for k, v in enumerate(distribution)} + train_by_class = defaultdict(list) + for text, label in zip(train_data['text'], train_data['label']): + train_by_class[label].append((text, label)) + + train_texts, train_labels = [], [] + for label, count in train_target_counts.items(): + samples = random.sample(train_by_class[label], count) + train_texts.extend([t for t, _ in samples]) + train_labels.extend([l for _, l in samples]) + print("Train samples:", len(train_texts), {l: train_labels.count(l) for l in set(train_labels)}) + + train_set = AGNEWS_DATASET(train_texts, train_labels, tokenizer, max_length=128) + train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) + return train_loader + else: + test_data = dataset['test'] + distribution = [500, 500, 500, 500] + test_target_counts = {k: v for k, v in enumerate(distribution)} + test_by_class = defaultdict(list) + for text, label in zip(test_data['text'], test_data['label']): + test_by_class[label].append((text, label)) + + test_texts, test_labels = [], [] + for label, count in test_target_counts.items(): + samples = random.sample(test_by_class[label], count) + test_texts.extend([t for t, _ in samples]) + test_labels.extend([l for _, l in samples]) + + print("Test samples:", len(test_texts), {l: test_labels.count(l) for l in set(test_labels)}) + + test_set = AGNEWS_DATASET(test_texts, test_labels, tokenizer, max_length=128) + test_loader = DataLoader(test_set, batch_size=100, shuffle=False) + return test_loader + +def CIFAR10(batch_size=None, distribution=None, train=True): + if train: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + + label_to_indices = defaultdict(list) + for idx, (_, label) in tqdm(enumerate(train_set)): + label_to_indices[int(label)].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + selected_indices.extend(random.sample(label_to_indices[label], count)) + subset = torch.utils.data.Subset(train_set, selected_indices) + + train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + 
transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=1) + return test_loader + +def SPEECHCOMMANDS(batch_size=None, distribution=None, train=True): + """Google Speech Commands V2 dataset loader for KWT model""" + if train: + dataset = SpeechCommandsDataset(root='./data', subset='training') + + if distribution is not None: + # Build label index from samples list directly (avoid reading wav files) + from src.dataset.SPEECHCOMMANDS import CLASSES + label_to_indices = defaultdict(list) + for idx, (audio_path, label_name) in enumerate(dataset.samples): + if label_name in CLASSES: + label_idx = CLASSES.index(label_name) + else: + label_idx = CLASSES.index('unknown') + label_to_indices[label_idx].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + if count > 0 and label in label_to_indices: + available = label_to_indices[label] + selected_indices.extend(random.sample(available, min(count, len(available)))) + + print(f"[DEBUG] Selected {len(selected_indices)} samples after distribution filter") + subset = torch.utils.data.Subset(dataset, selected_indices) + train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True) + else: + train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + dataset = SpeechCommandsDataset(root='./data', subset='testing') + test_loader = DataLoader(dataset, batch_size=20, shuffle=False) + return test_loader + +def data_loader(data_name=None, batch_size=None, distribution=None, train=True): + if data_name == 'AGNEWS': + data = AGNEWS(batch_size, distribution, train) + elif data_name == 'SPEECHCOMMANDS': + data = SPEECHCOMMANDS(batch_size, distribution, train) + else: + data = CIFAR10(batch_size, distribution, train) + + return data \ No newline at end of file diff --git a/other/DCSL/src/model/BERT_AGNEWS.py b/other/DCSL/src/model/BERT_AGNEWS.py index e5f2fcf..dabf59e 100644 --- a/other/DCSL/src/model/BERT_AGNEWS.py +++ b/other/DCSL/src/model/BERT_AGNEWS.py @@ -196,10 +196,7 @@ def __init__(self, vocab_size=28996, hidden_size=768, num_attention_heads=12, in self.layer14 = BertPooler(hidden_size) if self.start_layer < 15 <= self.end_layer: - self.layer15 = nn.Sequential( - nn.Dropout(dropout_prob), - nn.Linear(hidden_size, 4) - ) + self.layer15 = BertClassifier(hidden_size, 4) def forward(self, input_ids=None, token_type_ids=None, **kwargs): x = input_ids
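
For reference, a minimal sketch of how the new data_loader entry point from src/dataset/dataloader.py is meant to be called after this series: a layer-1 client builds its training loader from the per-class sample counts it receives (passed as the distribution argument), while validation requests the fixed, balanced test split. The batch size and label counts below are illustrative placeholders, not values taken from the repository's configs.

from src.dataset.dataloader import data_loader

# Hypothetical per-class sample counts for a 4-class AGNEWS client
# (in the actual flow these arrive as self.label_count on the RPC client).
label_count = [2000, 2000, 2000, 2000]

# Layer-1 client: subsample the AGNEWS training split according to label_count.
train_loader = data_loader("AGNEWS", batch_size=32, distribution=label_count, train=True)

# Server-side validation: balanced test split (500 samples per class, batch size 100).
test_loader = data_loader(data_name="AGNEWS", train=False)

# Each batch is a dict produced by AGNEWS_DATASET.__getitem__.
for batch in train_loader:
    input_ids = batch["input_ids"]            # (batch, 128) token ids
    attention_mask = batch["attention_mask"]  # (batch, 128) padding mask
    labels = batch["labels"]                  # (batch,) class indices 0..3
    break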