diff --git a/other/DCSL/src/Log.py b/other/DCSL/src/Log.py index 640ad13..f8f7eac 100644 --- a/other/DCSL/src/Log.py +++ b/other/DCSL/src/Log.py @@ -1,6 +1,5 @@ import logging - class Colors: COLORS = { "header": '\033[95m', @@ -11,23 +10,18 @@ class Colors: "end": '\033[0m' } - class Logger: def __init__(self, log_path, debug_mode=False): - # Thiết lập logger với tên "my_logger" self.logger = logging.getLogger("my_logger") self.logger.setLevel(logging.DEBUG) # Mức log self.debug_mode = debug_mode - # Tạo file handler để ghi log vào file file_handler = logging.FileHandler(log_path) file_handler.setLevel(logging.DEBUG) - # Định dạng log formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') file_handler.setFormatter(formatter) - # Gắn file handler vào logger self.logger.addHandler(file_handler) def log_info(self, message): diff --git a/other/DCSL/src/RpcClient.py b/other/DCSL/src/RpcClient.py index fde4e5a..51f2b98 100644 --- a/other/DCSL/src/RpcClient.py +++ b/other/DCSL/src/RpcClient.py @@ -1,15 +1,13 @@ import time import pickle -import random import copy -import torchvision -import torchvision.transforms as transforms - -from collections import defaultdict -from tqdm import tqdm import src.Log + from src.model import * +from src.dataset.dataloader import data_loader + +from peft import LoraConfig, get_peft_model class RpcClient: @@ -22,7 +20,9 @@ def __init__(self, client_id, layer_id, channel, train_func, device): self.response = None self.model = None + self.train_loader = None self.label_count = None + self.peft_config = None self.train_set = None self.label_to_indices = None @@ -56,34 +56,6 @@ def response_message(self, body): if self.label_count is not None: src.Log.print_with_color(f"Label distribution of client: {self.label_count}", "yellow") - # Load training dataset - if self.layer_id == 1 and data_name and not self.train_set and not self.label_to_indices: - if data_name == "MNIST": - transform_train = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) - self.train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, - transform=transform_train) - - elif data_name == "CIFAR10": - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, - transform=transform_train) - else: - self.train_set = None - raise ValueError(f"Data name '{data_name}' is not valid.") - - self.label_to_indices = defaultdict(list) - for idx, (_, label) in tqdm(enumerate(self.train_set)): - self.label_to_indices[int(label)].append(idx) - - # Load model if self.model is None: klass = globals()[f'{model_name}_{data_name}'] @@ -99,26 +71,38 @@ def response_message(self, body): lr = self.response["lr"] momentum = self.response["momentum"] sda_size = self.response.get("sda_size", 1) + layer2_devices = self.response.get("layer2_devices", []) - # Read parameters and load to model if state_dict: self.model.load_state_dict(state_dict) - # Start training - if self.layer_id == 1: - selected_indices = [] - for label, count in enumerate(self.label_count): - selected_indices.extend(random.sample(self.label_to_indices[label], count)) + if model_name == 'BERT': + if self.peft_config is None: + self.peft_config = LoraConfig( + task_type="SEQ_CLS", + r=8, lora_alpha=16, lora_dropout=0.1, + bias="none", + target_modules=["query", "key", "value", "dense"] + ) + self.model = get_peft_model(self.model, self.peft_config) + if self.layer_id == 2: + for param in self.model.layer15.parameters(): + param.requires_grad = True + + self.model.to(self.device) - subset = torch.utils.data.Subset(self.train_set, selected_indices) - train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + if self.layer_id == 1: + if self.train_loader is None: + self.train_loader = data_loader(data_name, batch_size, self.label_count, train=True) - result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round) + result, size = self.train_func(self.model, lr, momentum, self.train_loader, local_round=local_round, layer2_devices=layer2_devices, model_name=model_name) else: - result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size) + result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size, model_name=model_name) + + if model_name == 'BERT': + self.model = self.model.merge_and_unload() - # Stop training, then send parameters to server model_state_dict = copy.deepcopy(self.model.state_dict()) if self.device != "cpu": for key in model_state_dict: @@ -133,7 +117,6 @@ def response_message(self, body): elif action == "STOP": return False - def send_to_server(self, message): self.response = None diff --git a/other/DCSL/src/Scheduler.py b/other/DCSL/src/Scheduler.py index 6a85bd4..60d3bb4 100644 --- a/other/DCSL/src/Scheduler.py +++ b/other/DCSL/src/Scheduler.py @@ -18,9 +18,12 @@ def __init__(self, client_id, layer_id, channel, device): self.device = device self.data_count = 0 - def send_intermediate_output(self, output, labels, trace, data_id=None): + def send_intermediate_output(self, output, labels, trace, data_id=None, target_device_id=None): - forward_queue_name = f'intermediate_queue_{self.layer_id}' + if target_device_id is not None: + forward_queue_name = f'intermediate_queue_{target_device_id}' + else: + forward_queue_name = f'intermediate_queue_{self.layer_id}' self.channel.queue_declare(forward_queue_name, durable=False) @@ -63,12 +66,11 @@ def send_to_server(self, message): routing_key='rpc_queue', body=pickle.dumps(message)) - def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=3): - """ - Synchronous training: forward 1 batch → wait for gradient → backward → next batch. - Edge device does NOT send multiple batches before receiving gradient. - """ - optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=3, layer2_devices=None, model_name=None): + if model_name == 'BERT': + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + else: + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' self.channel.queue_declare(queue=backward_queue_name, durable=False) @@ -76,25 +78,42 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou model.to(self.device) + batch_counter = 0 + for i in range(local_round): src.Log.print_with_color(f'Epoch {i}', 'green') with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: - for training_data, labels in train_loader: - training_data = training_data.to(self.device) + for batch in train_loader: + if isinstance(batch, dict) and 'input_ids' in batch: + training_data = batch['input_ids'].to(self.device) + attention_mask = batch['attention_mask'].to(self.device) + labels = batch['labels'].to(self.device) + kwargs = {'input_ids': training_data, 'attention_mask': attention_mask} + else: + training_data, labels = batch + training_data = training_data.to(self.device) + labels = labels.to(self.device) + kwargs = {} - # Step 1: Forward data_id = str(uuid.uuid4()) - intermediate_output = model(training_data) + with torch.no_grad(): + if 'input_ids' in kwargs: + intermediate_output = model(**kwargs) + else: + intermediate_output = model(training_data, **kwargs) intermediate_output = intermediate_output.detach().requires_grad_(True) self.data_count += 1 pbar.update(1) - # Step 2: Send smashed data to server - self.send_intermediate_output(intermediate_output, labels, trace=None, data_id=data_id) + target_device_id = None + if layer2_devices: + target_device_id = layer2_devices[batch_counter % len(layer2_devices)] + batch_counter += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None, data_id=data_id, target_device_id=target_device_id) - # Step 3: Wait for gradient (blocking) while True: method_frame, header_frame, body = self.channel.basic_get( queue=backward_queue_name, auto_ack=True) @@ -102,10 +121,12 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou received_data = pickle.loads(body) gradient = torch.tensor(received_data["data"]).to(self.device) - # Step 4: Backward model.train() optimizer.zero_grad() - output = model(training_data) + if 'input_ids' in kwargs: + output = model(**kwargs) + else: + output = model(training_data, **kwargs) output.backward(gradient=gradient) optimizer.step() break @@ -128,17 +149,11 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou return True time.sleep(0.5) - def _process_sda_batch(self, model, optimizer, criterion, collected): - """ - SDA (Smashed Data Aggregation) — Eq. 4-5 from paper. - Concatenate smashed data from all clients, forward once, - split gradient back to each client. - """ + def _process_sda_batch(self, model, optimizer, criterion, collected, model_name=None): batch_sizes = [item["data"].shape[0] for item in collected] traces = [item["trace"] for item in collected] data_ids = [item["data_id"] for item in collected] - # Eq. 4: S_c = concat(σ_1, σ_2, ..., σ_|D_c|) all_data = np.concatenate([item["data"] for item in collected], axis=0) all_labels = np.concatenate([item["label"] for item in collected], axis=0) @@ -149,8 +164,10 @@ def _process_sda_batch(self, model, optimizer, criterion, collected): optimizer.zero_grad() concat_intermediate.retain_grad() - # Eq. 5: ŷ = f(S_c | W) - output = model(concat_intermediate) + if model_name == 'BERT': + output = model(input_ids=concat_intermediate) + else: + output = model(concat_intermediate) loss = criterion(output, concat_labels.long()) print(f"Loss (SDA, {len(collected)} clients, {sum(batch_sizes)} samples): {loss.item():.4f}") @@ -173,20 +190,18 @@ def _process_sda_batch(self, model, optimizer, criterion, collected): return result - def train_on_last_layer(self, model, lr, momentum, sda_size=1): - """ - SDA: collect exactly 1 batch from each client, - concat and forward once, split gradient back. - Since edge devices are synchronous, no overflow needed. - """ - optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + def train_on_last_layer(self, model, lr, momentum, sda_size=1, model_name=None): + if model_name == 'BERT': + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + else: + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) result = True criterion = nn.CrossEntropyLoss() - forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' + forward_queue_name = f'intermediate_queue_{self.client_id}' self.channel.queue_declare(queue=forward_queue_name, durable=False) self.channel.basic_qos(prefetch_count=1) - print(f'Waiting for intermediate output (SDA size={sda_size}). To exit press CTRL+C') + print(f'Waiting for intermediate output on queue {forward_queue_name} (SDA size={sda_size}). To exit press CTRL+C') model.to(self.device) sda_batch = {} # {client_id: data} — exactly 1 batch per client @@ -200,7 +215,7 @@ def train_on_last_layer(self, model, lr, momentum, sda_size=1): # When we have 1 batch from each client → SDA forward if len(sda_batch) >= sda_size: - batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values())) + batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values()), model_name=model_name) if not batch_result: result = False sda_batch = {} @@ -214,16 +229,16 @@ def train_on_last_layer(self, model, lr, momentum, sda_size=1): if received_data["action"] == "PAUSE": # Process remaining if sda_batch: - batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values())) + batch_result = self._process_sda_batch(model, optimizer, criterion, list(sda_batch.values()), model_name=model_name) if not batch_result: result = False return result - def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, sda_size=1): + def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, sda_size=1, layer2_devices=None, model_name=None): self.data_count = 0 if self.layer_id == 1: - result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round) + result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, layer2_devices=layer2_devices, model_name=model_name) else: - result = self.train_on_last_layer(model, lr, momentum, sda_size) + result = self.train_on_last_layer(model, lr, momentum, sda_size, model_name=model_name) return result, self.data_count \ No newline at end of file diff --git a/other/DCSL/src/Server.py b/other/DCSL/src/Server.py index 3468a27..94f9f29 100644 --- a/other/DCSL/src/Server.py +++ b/other/DCSL/src/Server.py @@ -76,6 +76,9 @@ def __init__(self, config): self.idx = 0 self.current_clients_cluster = 0 + self.current_lr = 0 + self.sda_size = 0 + self.channel.basic_qos(prefetch_count=1) self.reply_channel = self.connection.channel() self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request) @@ -89,16 +92,27 @@ def distribution(self): if self.non_iid: # label_distribution = np.random.dirichlet([self.data_distribution["dirichlet"]["alpha"]] * self.num_label, # self.total_clients[0]) - - label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + #VGG16 + # label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + # [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + # [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + # ]) + #KWT + label_distribution = np.array([[0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], ]) self.label_counts = (label_distribution * self.num_sample).astype(int) @@ -216,8 +230,9 @@ def on_request(self, ch, method, props, body): if self.validation and self.round_result: state_dict_full = self.concatenate_and_avg_clusters() self.avg_state_dict = [] - if not src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger): + if not src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger, self.connection): self.logger.log_warning("Training failed!") + self.round = 0 else: torch.save(state_dict_full, f'{self.model_name}_{self.data_name}.pth') self.round -= 1 @@ -228,9 +243,6 @@ def on_request(self, ch, method, props, body): if self.round > 0: current_round = self.global_round - self.round + 1 - # Step decay: reduce LR by lr_decay every lr_step rounds - num_decays = current_round // self.lr_step - self.current_lr = self.lr * (self.lr_decay ** num_decays) self.logger.log_info(f"Start training round {current_round}") self.logger.log_info(f"Learning rate: {self.current_lr}") self.label_ = copy.deepcopy(self.label_counts) @@ -244,6 +256,8 @@ def on_request(self, ch, method, props, body): def notify_clients(self, start=True, register=True, idx=0, avg_model=None): if start: + layer2_device_ids = [str(cid) for (cid, lid, cl) in self.list_clients if lid == 2] + if register: klass = globals()[f'{self.model_name}_{self.data_name}'] @@ -275,7 +289,10 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): keys = state_dict.keys() for key in keys: - state_dict[key] = full_state_dict[key] + if key in full_state_dict: + state_dict[key] = full_state_dict[key] + else: + state_dict[key] = full_state_dict[key] self.logger.log_info("Model loaded successfully.") else: self.logger.log_info(f"File {filepath} does not exist.") @@ -298,7 +315,8 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): "momentum": self.momentum, "label_count": label, "local_round": self.local_round, - "sda_size": self.sda_size + "sda_size": self.sda_size, + "layer2_devices": layer2_device_ids } self.send_to_response(client_id, pickle.dumps(response)) @@ -320,7 +338,8 @@ def notify_clients(self, start=True, register=True, idx=0, avg_model=None): "momentum": self.momentum, "label_count": label, "local_round": self.local_round, - "sda_size": self.sda_size + "sda_size": self.sda_size, + "layer2_devices": layer2_device_ids } self.send_to_response(client_id, pickle.dumps(response)) diff --git a/other/DCSL/src/Utils.py b/other/DCSL/src/Utils.py index dc42a72..2c7c815 100644 --- a/other/DCSL/src/Utils.py +++ b/other/DCSL/src/Utils.py @@ -1,5 +1,3 @@ -import numpy as np -import random import pika import torch @@ -34,12 +32,6 @@ def delete_old_queues(address, username, password, virtual_host): return False def fedavg_state_dicts(state_dicts, weights = None): - """ - Trung bình (FedAvg) một list các state_dict. - - state_dicts: list các dict {param_name: tensor} - - weights: list trọng số tương ứng (mặc định None nghĩa là mỗi model weight=1) - Trả về một dict {param_name: tensor_avg} - """ num = len(state_dicts) if num == 0: raise ValueError("fedavg_state_dicts: không có state_dict nào để trung bình.") @@ -48,12 +40,10 @@ def fedavg_state_dicts(state_dicts, weights = None): weights = [1.0] * num total_w = sum(weights) - # Tập hợp tất cả key all_keys = set().union(*(sd.keys() for sd in state_dicts)) avg_dict = {} for key in all_keys: - # gom tensor + weight, xử lý NaN acc = None for sd, w in zip(state_dicts, weights): if key not in sd: @@ -64,10 +54,8 @@ def fedavg_state_dicts(state_dicts, weights = None): t = t * w acc = t if acc is None else acc + t - # chia trung bình avg = acc / total_w - # cast về dtype gốc orig = next(sd[key] for sd in state_dicts if key in sd) if orig.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.bool): avg = avg.round().to(orig.dtype) diff --git a/other/DCSL/src/Validation.py b/other/DCSL/src/Validation.py index eb7dd7a..ebc9a25 100644 --- a/other/DCSL/src/Validation.py +++ b/other/DCSL/src/Validation.py @@ -1,33 +1,14 @@ - import numpy as np import math from tqdm import tqdm -import torchvision -import torchvision.transforms as transforms -import torch.nn as nn - +from src.dataset.dataloader import data_loader from src.model import * -def test(model_name, data_name, state_dict_full, logger): +def test(model_name, data_name, state_dict_full, logger, server_connection=None): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if data_name == "MNIST": - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) - testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test) - - elif data_name == "CIFAR10": - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) - else: - raise ValueError(f"Data name '{data_name}' is not valid.") - test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) + test_loader = data_loader(data_name=data_name, train=False) criterion = nn.CrossEntropyLoss() @@ -47,15 +28,28 @@ def test(model_name, data_name, state_dict_full, logger): total = 0 with torch.no_grad(): - for data, target in tqdm(test_loader): - data = data.to(device) - target = target.to(device) - output = model(data) + for i, batch in enumerate(tqdm(test_loader)): + if isinstance(batch, dict) and 'input_ids' in batch: + data = batch['input_ids'].to(device) + target = batch['labels'].to(device) + output = model(data, attention_mask=batch['attention_mask'].to(device)) + else: + data, target = batch + data = data.to(device) + target = target.to(device) + output = model(data) + loss = criterion(output, target) test_loss += loss.item() * target.size(0) pred = output.argmax(dim=1) correct += pred.eq(target).sum().item() total += target.size(0) + + if server_connection is not None and i % 5 == 0: + try: + server_connection.process_data_events(time_limit=0) + except Exception: + pass test_loss /= total accuracy = 100.0 * correct / total diff --git a/other/DCSL/src/dataset/AGNEWS.py b/other/DCSL/src/dataset/AGNEWS.py new file mode 100644 index 0000000..4fc37dc --- /dev/null +++ b/other/DCSL/src/dataset/AGNEWS.py @@ -0,0 +1,30 @@ +import torch + +class AGNEWS_DATASET(torch.utils.data.Dataset): + def __init__(self, texts, labels, tokenizer, max_length=128): + self.texts = texts + self.labels = labels + self.tokenizer = tokenizer + self.max_length = max_length + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = str(self.texts[idx]) + label = self.labels[idx] + + # Tokenize + encoding = self.tokenizer( + text, + truncation=True, + padding='max_length', + max_length=self.max_length, + return_tensors='pt' + ) + + return { + 'input_ids': encoding['input_ids'].flatten(), + 'attention_mask': encoding['attention_mask'].flatten(), + 'labels': torch.tensor(label, dtype=torch.long) + } \ No newline at end of file diff --git a/other/DCSL/src/dataset/SPEECHCOMMANDS.py b/other/DCSL/src/dataset/SPEECHCOMMANDS.py new file mode 100644 index 0000000..4caf008 --- /dev/null +++ b/other/DCSL/src/dataset/SPEECHCOMMANDS.py @@ -0,0 +1,182 @@ +import os +import random +import torch +import numpy as np +from torch.utils.data import Dataset +from scipy.io import wavfile +from scipy.fftpack import dct + +CLASSES = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 'silence', 'unknown'] + +def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=160, n_mels=40): + """Compute MFCC features using scipy (no torchaudio backend needed)""" + emphasized = np.append(waveform[0], waveform[1:] - 0.97 * waveform[:-1]) + + frame_length = n_fft + num_frames = 1 + (len(emphasized) - frame_length) // hop_length + frames = np.zeros((num_frames, frame_length)) + for i in range(num_frames): + frames[i] = emphasized[i * hop_length: i * hop_length + frame_length] + + frames *= np.hamming(frame_length) + + mag_frames = np.absolute(np.fft.rfft(frames, n_fft)) + pow_frames = (1.0 / n_fft) * (mag_frames ** 2) + + low_freq_mel = 0 + high_freq_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700) + mel_points = np.linspace(low_freq_mel, high_freq_mel, n_mels + 2) + hz_points = 700 * (10 ** (mel_points / 2595) - 1) + bin_points = np.floor((n_fft + 1) * hz_points / sample_rate).astype(int) + + fbank = np.zeros((n_mels, int(np.floor(n_fft / 2 + 1)))) + for m in range(1, n_mels + 1): + f_m_minus = bin_points[m - 1] + f_m = bin_points[m] + f_m_plus = bin_points[m + 1] + for k in range(f_m_minus, f_m): + fbank[m - 1, k] = (k - f_m_minus) / (f_m - f_m_minus) + for k in range(f_m, f_m_plus): + fbank[m - 1, k] = (f_m_plus - k) / (f_m_plus - f_m) + + filter_banks = np.dot(pow_frames, fbank.T) + filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks) + filter_banks = 20 * np.log10(filter_banks) + + mfcc = dct(filter_banks, type=2, axis=1, norm='ortho')[:, :n_mfcc] + + return mfcc.T + + +def _load_background_noise(root): + """Load all background noise files from _background_noise_ directory""" + noise_dir = os.path.join(root, '_background_noise_') + noise_data = [] + if not os.path.exists(noise_dir): + print(f"[WARNING] Background noise dir not found: {noise_dir}") + return noise_data + + for f in os.listdir(noise_dir): + if not f.endswith('.wav'): + continue + fpath = os.path.join(noise_dir, f) + try: + sr, wav = wavfile.read(fpath) + if wav.dtype == np.int16: + wav = wav.astype(np.float32) / 32768.0 + elif wav.dtype == np.int32: + wav = wav.astype(np.float32) / 2147483648.0 + else: + wav = wav.astype(np.float32) + noise_data.append(wav) + except Exception as e: + print(f"[WARNING] Error reading noise file {fpath}: {e}") + + print(f"Loaded {len(noise_data)} background noise files") + return noise_data + +class SpeechCommandsDataset(Dataset): + def __init__(self, root='./data', subset='training', n_silence=2300): + self.root = os.path.join(root, 'SpeechCommands', 'speech_commands_v0.02') + self.subset = subset + self.samples = [] + self.noise_data = [] + + if not os.path.exists(self.root): + raise RuntimeError(f"Dataset not found at {self.root}. Please download manually.") + + self.noise_data = _load_background_noise(self.root) + + val_list = set() + test_list = set() + + val_file = os.path.join(self.root, 'validation_list.txt') + test_file = os.path.join(self.root, 'testing_list.txt') + + if os.path.exists(val_file): + with open(val_file) as f: + val_list = set(line.strip() for line in f) + if os.path.exists(test_file): + with open(test_file) as f: + test_list = set(line.strip() for line in f) + + for label_dir in os.listdir(self.root): + label_path = os.path.join(self.root, label_dir) + if not os.path.isdir(label_path) or label_dir.startswith('_'): + continue + + for audio_file in os.listdir(label_path): + if not audio_file.endswith('.wav'): + continue + + rel_path = os.path.join(label_dir, audio_file) + + if subset == 'validation' and rel_path in val_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'testing' and rel_path in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'training' and rel_path not in val_list and rel_path not in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + + if self.noise_data: + if subset == 'training': + num_silence = n_silence + else: + num_silence = max(1, n_silence // 9) + + for _ in range(num_silence): + self.samples.append((None, 'silence')) + + print(f"Added {num_silence} silence samples from background noise") + + print(f"Total {len(self.samples)} samples for {subset}") + + def _get_silence_waveform(self): + target_length = 16000 + noise = random.choice(self.noise_data) + + if len(noise) <= target_length: + waveform = np.pad(noise, (0, max(0, target_length - len(noise)))) + else: + start = random.randint(0, len(noise) - target_length) + waveform = noise[start: start + target_length] + + return waveform + + def __getitem__(self, idx): + audio_path, label = self.samples[idx] + + try: + if audio_path is None: + waveform = self._get_silence_waveform() + else: + sample_rate, waveform = wavfile.read(audio_path) + if waveform.dtype == np.int16: + waveform = waveform.astype(np.float32) / 32768.0 + elif waveform.dtype == np.int32: + waveform = waveform.astype(np.float32) / 2147483648.0 + else: + waveform = waveform.astype(np.float32) + + target_length = 16000 + if len(waveform) < target_length: + waveform = np.pad(waveform, (0, target_length - len(waveform))) + else: + waveform = waveform[:target_length] + + mfcc = compute_mfcc(waveform, sample_rate=16000, n_mfcc=40) + mfcc = torch.tensor(mfcc, dtype=torch.float32) + + except Exception as e: + print(f"[WARNING] Error processing sample {idx}: {e}, using zeros") + mfcc = torch.zeros(40, 98, dtype=torch.float32) + + if label in CLASSES: + label_idx = CLASSES.index(label) + else: + label_idx = CLASSES.index('unknown') + + return mfcc, label_idx + + def __len__(self): + return len(self.samples) diff --git a/other/DCSL/src/dataset/dataloader.py b/other/DCSL/src/dataset/dataloader.py new file mode 100644 index 0000000..5594bdc --- /dev/null +++ b/other/DCSL/src/dataset/dataloader.py @@ -0,0 +1,135 @@ +import random + +import torch +import torchvision +from collections import defaultdict +from tqdm import tqdm + +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from transformers import BertTokenizer +from datasets import load_dataset + +from src.dataset.AGNEWS import AGNEWS_DATASET +from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + +def AGNEWS(batch_size=None, distribution=None, train=True): + dataset = load_dataset( + 'ag_news', + download_mode='reuse_dataset_if_exists', + cache_dir='./hf_cache' + ) + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + if train: + train_data = dataset['train'] + train_target_counts = {k: v for k, v in enumerate(distribution)} + train_by_class = defaultdict(list) + for text, label in zip(train_data['text'], train_data['label']): + train_by_class[label].append((text, label)) + + train_texts, train_labels = [], [] + for label, count in train_target_counts.items(): + samples = random.sample(train_by_class[label], count) + train_texts.extend([t for t, _ in samples]) + train_labels.extend([l for _, l in samples]) + print("Train samples:", len(train_texts), {l: train_labels.count(l) for l in set(train_labels)}) + + train_set = AGNEWS_DATASET(train_texts, train_labels, tokenizer, max_length=128) + train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) + return train_loader + else: + test_data = dataset['test'] + distribution = [500, 500, 500, 500] + test_target_counts = {k: v for k, v in enumerate(distribution)} + test_by_class = defaultdict(list) + for text, label in zip(test_data['text'], test_data['label']): + test_by_class[label].append((text, label)) + + test_texts, test_labels = [], [] + for label, count in test_target_counts.items(): + samples = random.sample(test_by_class[label], count) + test_texts.extend([t for t, _ in samples]) + test_labels.extend([l for _, l in samples]) + + print("Test samples:", len(test_texts), {l: test_labels.count(l) for l in set(test_labels)}) + + test_set = AGNEWS_DATASET(test_texts, test_labels, tokenizer, max_length=128) + test_loader = DataLoader(test_set, batch_size=100, shuffle=False) + return test_loader + +def CIFAR10(batch_size=None, distribution=None, train=True): + if train: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + + label_to_indices = defaultdict(list) + for idx, (_, label) in tqdm(enumerate(train_set)): + label_to_indices[int(label)].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + selected_indices.extend(random.sample(label_to_indices[label], count)) + subset = torch.utils.data.Subset(train_set, selected_indices) + + train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=1) + return test_loader + +def SPEECHCOMMANDS(batch_size=None, distribution=None, train=True): + """Google Speech Commands V2 dataset loader for KWT model""" + if train: + dataset = SpeechCommandsDataset(root='./data', subset='training') + + if distribution is not None: + # Build label index from samples list directly (avoid reading wav files) + from src.dataset.SPEECHCOMMANDS import CLASSES + label_to_indices = defaultdict(list) + for idx, (audio_path, label_name) in enumerate(dataset.samples): + if label_name in CLASSES: + label_idx = CLASSES.index(label_name) + else: + label_idx = CLASSES.index('unknown') + label_to_indices[label_idx].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + if count > 0 and label in label_to_indices: + available = label_to_indices[label] + selected_indices.extend(random.sample(available, min(count, len(available)))) + + print(f"[DEBUG] Selected {len(selected_indices)} samples after distribution filter") + subset = torch.utils.data.Subset(dataset, selected_indices) + train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True) + else: + train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + dataset = SpeechCommandsDataset(root='./data', subset='testing') + test_loader = DataLoader(dataset, batch_size=20, shuffle=False) + return test_loader + +def data_loader(data_name=None, batch_size=None, distribution=None, train=True): + if data_name == 'AGNEWS': + data = AGNEWS(batch_size, distribution, train) + elif data_name == 'SPEECHCOMMANDS': + data = SPEECHCOMMANDS(batch_size, distribution, train) + else: + data = CIFAR10(batch_size, distribution, train) + + return data \ No newline at end of file diff --git a/other/DCSL/src/model/BERT_AGNEWS.py b/other/DCSL/src/model/BERT_AGNEWS.py new file mode 100644 index 0000000..dabf59e --- /dev/null +++ b/other/DCSL/src/model/BERT_AGNEWS.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class DotDict(dict): + def __getattr__(self, k): + try: return self[k] + except KeyError: raise AttributeError(k) + def __setattr__(self, k, v): self[k] = v + def __delattr__(self, k): del self[k] + +# AG_NEWS have 4 layers: World, Sports, Business, Sci/Tech +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + +class BertSdpaSelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, dropout_prob): + super(BertSdpaSelfAttention, self).__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + self.dropout = nn.Dropout(dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + import math + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob) + self.output = BertSelfOutput(hidden_size, dropout_prob) + + def forward(self, hidden_states): + self_output = self.self(hidden_states) + attention_output = self.output(self_output, hidden_states) + return attention_output + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = nn.GELU() + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_prob): + super(BertLayer, self).__init__() + self.attention = BertAttention(hidden_size, num_attention_heads, dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size) + self.output = BertOutput(hidden_size, intermediate_size, dropout_prob) + + def forward(self, hidden_states): + attention_output = self.attention(hidden_states) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class BertClassifier(nn.Module): + def __init__(self, hidden_size, num_labels, dropout_prob=0.1): + super(BertClassifier, self).__init__() + self.dropout = nn.Dropout(dropout_prob) + self.classifier = nn.Linear(hidden_size, num_labels) + + def forward(self, pooled_output): + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + +class BERT_AGNEWS(nn.Module): + def __init__(self, vocab_size=28996, hidden_size=768, num_attention_heads=12, intermediate_size=3072, + max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, n_block=12, + start_layer=0, end_layer=15): + super(BERT_AGNEWS, self).__init__() + self.start_layer = start_layer + self.end_layer = 15 if end_layer == -1 else end_layer + self.config = DotDict( + model_type="bert", + vocab_size=vocab_size, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + bos_token_id=101, eos_token_id=102, pad_token_id=0, + is_encoder_decoder=False, tie_word_embeddings=False, + use_return_dict=True, output_attentions=False, output_hidden_states=False + ) + + if self.start_layer < 1 <= self.end_layer: + self.layer1 = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) + + for i in range(12): + layer_idx = 2 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob)) + + if self.start_layer < 14 <= self.end_layer: + self.layer14 = BertPooler(hidden_size) + + if self.start_layer < 15 <= self.end_layer: + self.layer15 = BertClassifier(hidden_size, 4) + + def forward(self, input_ids=None, token_type_ids=None, **kwargs): + x = input_ids + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x, token_type_ids) + + for i in range(12): + layer_idx = 2 + i + if self.start_layer < layer_idx <= self.end_layer: + layer_module = getattr(self, f'layer{layer_idx}') + x = layer_module(x) + + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + + return x \ No newline at end of file diff --git a/other/DCSL/src/model/BERT_EMOTION.py b/other/DCSL/src/model/BERT_EMOTION.py deleted file mode 100644 index efb4ab9..0000000 --- a/other/DCSL/src/model/BERT_EMOTION.py +++ /dev/null @@ -1,428 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -# Configuration constants -NUM_SAMPLES = 5000 -num_labels = 6 -vocab_size = 30522 -hidden_size = 768 -num_hidden_layers = 12 -num_attention_heads = 12 -intermediate_size = 3072 -max_position_embeddings = 512 -type_vocab_size = 2 -dropout_prob = 0.1 - -# BertEmbeddings class -class BertEmbeddings(nn.Module): - def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - -# BertSdpaSelfAttention class -class BertSdpaSelfAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, dropout_prob): - super(BertSdpaSelfAttention, self).__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) - self.dropout = nn.Dropout(dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask=None): - batch_size, seq_length, hidden_size = hidden_states.size() - - # Create query, key, value projections - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - # Reshape for multi-head attention - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - # Perform attention score calculation - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - # Scale attention scores - import math - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Apply attention mask if provided - if attention_mask is not None: - # Reshape attention_mask from [batch_size, seq_length] to [batch_size, 1, 1, seq_length] - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - # Convert 1s (valid tokens) to 0s and 0s (padding) to large negative values - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - attention_scores = attention_scores + extended_attention_mask - - # Apply softmax to get probabilities - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - # Calculate context by attending to values - context_layer = torch.matmul(attention_probs, value_layer) - - # Reshape back to [batch_size, seq_length, hidden_size] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - -# BertSelfOutput class -class BertSelfOutput(nn.Module): - def __init__(self, hidden_size, dropout_prob): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - -# BertAttention class -class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, dropout_prob): - super(BertAttention, self).__init__() - self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob) - self.output = BertSelfOutput(hidden_size, dropout_prob) - - def forward(self, hidden_states, attention_mask=None): - self_output = self.self(hidden_states, attention_mask) - attention_output = self.output(self_output, hidden_states) - return attention_output - -# BertIntermediate class -class BertIntermediate(nn.Module): - def __init__(self, hidden_size, intermediate_size): - super(BertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = nn.GELU() - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - -# BertOutput class -class BertOutput(nn.Module): - def __init__(self, hidden_size, intermediate_size, dropout_prob): - super(BertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -# BertPooler class -class BertPooler(nn.Module): - def __init__(self, hidden_size): - super(BertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - -# BertClassifier class -class BertClassifier(nn.Module): - def __init__(self, hidden_size, num_labels, dropout_prob=0.1): - super(BertClassifier, self).__init__() - self.dropout = nn.Dropout(dropout_prob) - self.classifier = nn.Linear(hidden_size, num_labels) - - def forward(self, pooled_output): - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - return logits - -# Complete BERT Model -class BERT_EMOTION(nn.Module): - def __init__(self,start_layer= 0, end_layer= 27, vocab_size=30522, hidden_size=768, intermediate_size=3072, - num_attention_heads=12, num_labels=4, max_position_embeddings=512, - type_vocab_size=2, dropout_prob=0.1, num_hidden_layers=12): - - super(BERT_EMOTION, self).__init__() - - self.start_layer = start_layer - self.end_layer = end_layer - - if (self.start_layer < 1) and (self.end_layer >= 1): - self.layer1 = BertEmbeddings(vocab_size, hidden_size , max_position_embeddings, type_vocab_size, dropout_prob) - - if (self.start_layer < 2) and (self.end_layer >= 2): - self.layer2 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 3) and (self.end_layer >= 3): - self.layer3 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 4) and (self.end_layer >= 4): - self.layer4 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 5) and (self.end_layer >= 5): - self.layer5 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 6) and (self.end_layer >= 6): - self.layer6 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 7) and (self.end_layer >= 7): - self.layer7 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 8) and (self.end_layer >= 8): - self.layer8 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 9) and (self.end_layer >= 9): - self.layer9 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 10) and (self.end_layer >= 10): - self.layer10 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 11) and (self.end_layer >= 11): - self.layer11 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 12) and (self.end_layer >= 12): - self.layer12 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 13) and (self.end_layer >= 13): - self.layer13 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 14) and (self.end_layer >= 14): - self.layer14 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 15) and (self.end_layer >= 15): - self.layer15 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 16) and (self.end_layer >= 16): - self.layer16 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 17) and (self.end_layer >= 17): - self.layer17 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 18) and (self.end_layer >= 18): - self.layer18 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 19) and (self.end_layer >= 19): - self.layer19 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 20) and (self.end_layer >= 20): - self.layer20 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 21) and (self.end_layer >= 21): - self.layer21 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 22) and (self.end_layer >= 22): - self.layer22 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 23) and (self.end_layer >= 23): - self.layer23 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 24) and (self.end_layer >= 24): - self.layer24 = nn.ModuleList([ - BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), - BertSelfOutput(hidden_size, dropout_prob) - ]) - - if (self.start_layer < 25) and (self.end_layer >= 25): - self.layer25 = nn.ModuleList([ - BertIntermediate(hidden_size, intermediate_size), - BertOutput(hidden_size, intermediate_size, dropout_prob) - ]) - - if (self.start_layer < 26) and (self.end_layer >= 26): - self.layer26 = BertPooler(hidden_size) - - if (self.start_layer < 27) and (self.end_layer >= 27): - self.layer27 = BertClassifier(hidden_size, num_labels, dropout_prob) - - def forward(self, x, attention_mask=None, token_type_ids=None): - if (self.start_layer < 1) and (self.end_layer >= 1): - x = self.layer1(x, token_type_ids) - - if (self.start_layer < 2) and (self.end_layer >= 2): - x = self.layer2[1](self.layer2[0](x, attention_mask), x) - - if (self.start_layer < 3) and (self.end_layer >= 3): - x = self.layer3[1](self.layer3[0](x), x) - - if (self.start_layer < 4) and (self.end_layer >= 4): - x = self.layer4[1](self.layer4[0](x, attention_mask), x) - - if (self.start_layer < 5) and (self.end_layer >= 5): - x = self.layer5[1](self.layer5[0](x), x) - - if (self.start_layer < 6) and (self.end_layer >= 6): - x = self.layer6[1](self.layer6[0](x, attention_mask), x) - - if (self.start_layer < 7) and (self.end_layer >= 7): - x = self.layer7[1](self.layer7[0](x), x) - - if (self.start_layer < 8) and (self.end_layer >= 8): - x = self.layer8[1](self.layer8[0](x, attention_mask), x) - - if (self.start_layer < 9) and (self.end_layer >= 9): - x = self.layer9[1](self.layer9[0](x), x) - - if (self.start_layer < 10) and (self.end_layer >= 10): - x = self.layer10[1](self.layer10[0](x, attention_mask), x) - - if (self.start_layer < 11) and (self.end_layer >= 11): - x = self.layer11[1](self.layer11[0](x), x) - - if (self.start_layer < 12) and (self.end_layer >= 12): - x = self.layer12[1](self.layer12[0](x, attention_mask), x) - - if (self.start_layer < 13) and (self.end_layer >= 13): - x = self.layer13[1](self.layer13[0](x), x) - - if (self.start_layer < 14) and (self.end_layer >= 14): - x = self.layer14[1](self.layer14[0](x, attention_mask), x) - - if (self.start_layer < 15) and (self.end_layer >= 15): - x = self.layer15[1](self.layer15[0](x), x) - - if (self.start_layer < 16) and (self.end_layer >= 16): - x = self.layer16[1](self.layer16[0](x, attention_mask), x) - - if (self.start_layer < 17) and (self.end_layer >= 17): - x = self.layer17[1](self.layer17[0](x), x) - - if (self.start_layer < 18) and (self.end_layer >= 18): - x = self.layer18[1](self.layer18[0](x, attention_mask), x) - - if (self.start_layer < 19) and (self.end_layer >= 19): - x = self.layer19[1](self.layer19[0](x), x) - - if (self.start_layer < 20) and (self.end_layer >= 20): - x = self.layer20[1](self.layer20[0](x, attention_mask), x) - - if (self.start_layer < 21) and (self.end_layer >= 21): - x = self.layer21[1](self.layer21[0](x), x) - - if (self.start_layer < 22) and (self.end_layer >= 22): - x = self.layer22[1](self.layer22[0](x, attention_mask), x) - - if (self.start_layer < 23) and (self.end_layer >= 23): - x = self.layer23[1](self.layer23[0](x), x) - - if (self.start_layer < 24) and (self.end_layer >= 24): - x = self.layer24[1](self.layer24[0](x, attention_mask), x) - - if (self.start_layer < 25) and (self.end_layer >= 25): - x = self.layer25[1](self.layer25[0](x), x) - - if (self.start_layer < 26) and (self.end_layer >= 26): - x = self.layer26(x) - - if (self.start_layer < 27) and (self.end_layer >= 27): - x = self.layer27(x) - - return x diff --git a/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py b/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py new file mode 100644 index 0000000..4f716b1 --- /dev/null +++ b/other/DCSL/src/model/KWT_SPEECHCOMMANDS.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=1, mlp_dim=256): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + # Attention + residual + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + x_attn + # MLP + residual + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class KWT_SPEECHCOMMANDS(nn.Module): + """ + KWT-1: dim=64, mlp_dim=256, heads=1, layers=12 + Layers: + 1: Linear embedding (n_mfcc -> embed_dim) + 2: CLS token concatenation + 3: Positional embedding + Dropout + 4-15: 12x Transformer encoder blocks + 16: LayerNorm + 17: Classification head + """ + + def __init__(self, start_layer=0, end_layer=17): + super().__init__() + self.start_layer = start_layer + self.end_layer = 17 if end_layer == -1 else end_layer + + n_mfcc = 40 + time_steps = 98 + embed_dim = 64 + num_heads = 1 + mlp_dim = 256 + num_classes = 12 + dropout = 0.1 + + # Layer 1: Linear embedding + if self.start_layer < 1 <= self.end_layer: + self.layer1 = nn.Linear(n_mfcc, embed_dim) + + # Layer 2: CLS token + if self.start_layer < 2 <= self.end_layer: + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + # Layer 3: Positional embedding + Dropout + if self.start_layer < 3 <= self.end_layer: + self.pos_embed = nn.Parameter(torch.randn(1, time_steps + 1, embed_dim)) + self.dropout = nn.Dropout(dropout) + + # Layers 4-15: 12x Transformer encoder blocks + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + TransformerEncoderBlock(embed_dim, num_heads, mlp_dim)) + + # Layer 16: LayerNorm + if self.start_layer < 16 <= self.end_layer: + self.layer16 = nn.LayerNorm(embed_dim) + + # Layer 17: Classification head + if self.start_layer < 17 <= self.end_layer: + self.layer17 = nn.Linear(embed_dim, num_classes) + + self._init_weights() + + def _init_weights(self): + if hasattr(self, 'cls_token'): + nn.init.trunc_normal_(self.cls_token, std=0.02) + if hasattr(self, 'pos_embed'): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + + # Layer 1: Linear embedding + if self.start_layer < 1 <= self.end_layer: + x = x.transpose(1, 2) + x = self.layer1(x) + + # Layer 2: CLS token + if self.start_layer < 2 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) + + # Layer 3: Positional embedding + Dropout + if self.start_layer < 3 <= self.end_layer: + x = x + self.pos_embed + x = self.dropout(x) + + # Layers 4-15: 12x Transformer blocks + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + x = getattr(self, f'layer{layer_idx}')(x) + + # Layer 16: LayerNorm on CLS token + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x[:, 0]) + + # Layer 17: Classification + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + + return x diff --git a/other/DCSL/src/model/__init__.py b/other/DCSL/src/model/__init__.py index b0304dd..b1922ce 100644 --- a/other/DCSL/src/model/__init__.py +++ b/other/DCSL/src/model/__init__.py @@ -4,3 +4,5 @@ from .VGG16_MNIST import * from .ViT_CIFAR10 import * from .ViT_MNIST import * +from .BERT_AGNEWS import * +from .KWT_SPEECHCOMMANDS import *