From 8fd74da30638a8a05309fcb4dbe89ef7a7a97a5f Mon Sep 17 00:00:00 2001 From: nanhhao04 Date: Thu, 12 Mar 2026 15:14:02 +0700 Subject: [PATCH 1/2] add 2LS method --- other/2LS/client.py | 61 +++ other/2LS/config.yaml | 41 ++ other/2LS/readme.md | 25 ++ other/2LS/server.py | 32 ++ other/2LS/src/Log.py | 53 +++ other/2LS/src/RpcClient.py | 148 +++++++ other/2LS/src/Scheduler.py | 230 +++++++++++ other/2LS/src/Server.py | 460 +++++++++++++++++++++ other/2LS/src/Utils.py | 79 ++++ other/2LS/src/Validation.py | 65 +++ other/2LS/src/model/BERT_EMOTION.py | 428 +++++++++++++++++++ other/2LS/src/model/MobileNetv1_CIFAR10.py | 185 +++++++++ other/2LS/src/model/MobileNetv1_MNIST.py | 185 +++++++++ other/2LS/src/model/VGG16_CIFAR10.py | 230 +++++++++++ other/2LS/src/model/VGG16_MNIST.py | 226 ++++++++++ other/2LS/src/model/ViT_CIFAR10.py | 116 ++++++ other/2LS/src/model/ViT_MNIST.py | 116 ++++++ other/2LS/src/model/__init__.py | 6 + 18 files changed, 2686 insertions(+) create mode 100644 other/2LS/client.py create mode 100644 other/2LS/config.yaml create mode 100644 other/2LS/readme.md create mode 100644 other/2LS/server.py create mode 100644 other/2LS/src/Log.py create mode 100644 other/2LS/src/RpcClient.py create mode 100644 other/2LS/src/Scheduler.py create mode 100644 other/2LS/src/Server.py create mode 100644 other/2LS/src/Utils.py create mode 100644 other/2LS/src/Validation.py create mode 100644 other/2LS/src/model/BERT_EMOTION.py create mode 100644 other/2LS/src/model/MobileNetv1_CIFAR10.py create mode 100644 other/2LS/src/model/MobileNetv1_MNIST.py create mode 100644 other/2LS/src/model/VGG16_CIFAR10.py create mode 100644 other/2LS/src/model/VGG16_MNIST.py create mode 100644 other/2LS/src/model/ViT_CIFAR10.py create mode 100644 other/2LS/src/model/ViT_MNIST.py create mode 100644 other/2LS/src/model/__init__.py diff --git a/other/2LS/client.py b/other/2LS/client.py new file mode 100644 index 0000000..7e0f21d --- /dev/null +++ b/other/2LS/client.py @@ -0,0 +1,61 @@ 
+import pika +import uuid +import argparse +import yaml + +import torch + +import src.Log +from src.RpcClient import RpcClient +from src.Scheduler import Scheduler + +parser = argparse.ArgumentParser(description="Split learning framework") +parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1') +parser.add_argument('--device', type=str, required=False, help='Device of client') +# add new argument +parser.add_argument('--idx', type=int, required=True, help='index of client') +parser.add_argument('--incluster', type=int, required=False, default=0, help='In-cluster ID') +parser.add_argument('--outcluster', type=int, required=False, default=0, help='Out-cluster ID') +args = parser.parse_args() + +with open('config.yaml', 'r') as file: + config = yaml.safe_load(file) + +client_id = uuid.uuid4() +address = config["rabbit"]["address"] +username = config["rabbit"]["username"] +password = config["rabbit"]["password"] +virtual_host = config["rabbit"]["virtual-host"] + +device = None +if args.device is None: + if torch.cuda.is_available(): + device = "cuda" + print(f"Using device: {torch.cuda.get_device_name(device)}") + else: + device = "cpu" + print(f"Using device: CPU") +else: + device = args.device + print(f"Using device: {device}") + +credentials = pika.PlainCredentials(username, password) +connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) +channel = connection.channel() + +in_cluster_id = args.incluster +out_cluster_id = args.outcluster +idx = args.idx + +if __name__ == "__main__": + src.Log.print_with_color("[>>>] Client sending registration message to server...", "red") + + data = {"action": "REGISTER", "client_id": client_id, "idx": idx, "layer_id": args.layer_id, + "in_cluster_id": in_cluster_id, "out_cluster_id": out_cluster_id, "message": "Hello from Client!"} + + scheduler = Scheduler(client_id, args.layer_id, channel, device, in_cluster_id=in_cluster_id, idx=idx) + 
+ client = RpcClient(client_id, args.layer_id, channel, scheduler.train_on_device, device) + client.send_to_server(data) + client.wait_response() + diff --git a/other/2LS/config.yaml b/other/2LS/config.yaml new file mode 100644 index 0000000..a2b916e --- /dev/null +++ b/other/2LS/config.yaml @@ -0,0 +1,41 @@ +name: Split Learning +server: + local-round: 1 + global-round: 1 + + clients: + - 1 + - 1 + no-cluster: + cut-layers: [1] + manual-cluster: + num-cluster: 1 + cut-layers: [1] + + model: VGG16 + data-name: CIFAR10 + parameters: + load: False + save: False + validation: False + data-distribution: + non-iid: False + num-sample: 5000 + num-label: 10 + dirichlet: + alpha: 1 + random-seed: 1 + +rabbit: + address: 127.0.0.1 + username: admin + password: admin + virtual-host: / + +log_path: . +debug_mode: True + +learning: + learning-rate: 0.01 + momentum: 0.5 + batch-size: 32 diff --git a/other/2LS/readme.md b/other/2LS/readme.md new file mode 100644 index 0000000..1d69d85 --- /dev/null +++ b/other/2LS/readme.md @@ -0,0 +1,25 @@ +# SERVER + +``` +python3 server.py + +# SPLIT SERVER (Layer 2) +python3 client.py --layer_id 2 --idx 0 --incluster 0 --outcluster 0 +python3 client.py --layer_id 2 --idx 1 --incluster 0 --outcluster 0 +python3 client.py --layer_id 2 --idx 2 --incluster 1 --outcluster 0 + +# OUT-CLUSTER 0 - Layer 1 +python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 0 +python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 0 +python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 0 + + +# OUT-CLUSTER 1 - Layer 1 +python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 1 +python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 1 +python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 1 + +# OUT-CLUSTER 2 - Layer 1 +python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 2 +python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 2 +python3 client.py --layer_id 1 --idx 2 
--incluster 1 --outcluster 2 \ No newline at end of file diff --git a/other/2LS/server.py b/other/2LS/server.py new file mode 100644 index 0000000..4fa6369 --- /dev/null +++ b/other/2LS/server.py @@ -0,0 +1,32 @@ +import argparse +import sys +import signal +from src.Server import Server +from src.Utils import delete_old_queues +import src.Log +import yaml + +parser = argparse.ArgumentParser(description="Split learning framework with controller.") + +args = parser.parse_args() + +with open('config.yaml') as file: + config = yaml.safe_load(file) +address = config["rabbit"]["address"] +username = config["rabbit"]["username"] +password = config["rabbit"]["password"] +virtual_host = config["rabbit"]["virtual-host"] + + +def signal_handler(sig, frame): + print("\nCatch stop signal Ctrl+C. Stop the program.") + delete_old_queues(address, username, password, virtual_host) + sys.exit(0) + + +if __name__ == "__main__": + signal.signal(signal.SIGINT, signal_handler) + delete_old_queues(address, username, password, virtual_host) + server = Server(config) + server.start() + src.Log.print_with_color("Ok, ready!", "green") diff --git a/other/2LS/src/Log.py b/other/2LS/src/Log.py new file mode 100644 index 0000000..640ad13 --- /dev/null +++ b/other/2LS/src/Log.py @@ -0,0 +1,53 @@ +import logging + + +class Colors: + COLORS = { + "header": '\033[95m', + "blue": '\033[94m', + "green": '\033[92m', + "yellow": '\033[93m', + "red": '\033[91m', + "end": '\033[0m' + } + + +class Logger: + def __init__(self, log_path, debug_mode=False): + # Thiết lập logger với tên "my_logger" + self.logger = logging.getLogger("my_logger") + self.logger.setLevel(logging.DEBUG) # Mức log + self.debug_mode = debug_mode + + # Tạo file handler để ghi log vào file + file_handler = logging.FileHandler(log_path) + file_handler.setLevel(logging.DEBUG) + + # Định dạng log + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + + # Gắn file 
handler vào logger + self.logger.addHandler(file_handler) + + def log_info(self, message): + print(f"[INFO] {message}") + self.logger.info(message) + + def log_warning(self, message): + print_with_color(f"[WARN] {message}", "yellow") + self.logger.warning(message) + + def log_error(self, message): + print_with_color(f"[ERROR] {message}", "red") + self.logger.error(message) + + def log_debug(self, message): + if self.debug_mode: + print_with_color(f"[DEBUG] {message}", "green") + self.logger.debug(message) + + +def print_with_color(text, color): + color_code = Colors.COLORS.get(color.lower(), Colors.COLORS["end"]) + print(f"{color_code}{text}{Colors.COLORS['end']}") diff --git a/other/2LS/src/RpcClient.py b/other/2LS/src/RpcClient.py new file mode 100644 index 0000000..63825e2 --- /dev/null +++ b/other/2LS/src/RpcClient.py @@ -0,0 +1,148 @@ +import time +import pickle +import random +import copy +import torchvision +import torchvision.transforms as transforms + +from collections import defaultdict +from tqdm import tqdm + +import src.Log +from src.model import * + + +class RpcClient: + def __init__(self, client_id, layer_id, channel, train_func, device): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.train_func = train_func + self.device = device + + self.response = None + self.model = None + self.label_count = None + + self.train_set = None + self.label_to_indices = None + + def wait_response(self): + status = True + reply_queue_name = f'reply_{self.client_id}' + self.channel.queue_declare(reply_queue_name, durable=False) + while status: + method_frame, header_frame, body = self.channel.basic_get(queue=reply_queue_name, auto_ack=True) + if body: + status = self.response_message(body) + time.sleep(0.5) + + def response_message(self, body): + self.response = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Client received: {self.response.get('message', 'No message')}", "blue") + action = self.response["action"] + 
state_dict = self.response["parameters"] + + if action == "START": + model_name = self.response["model_name"] + cut_layers = self.response['layers'] + label_count = self.response['label_count'] + data_name = self.response["data_name"] + local_round = self.response["local_round"] + + if self.label_count is None: + self.label_count = label_count + + if self.label_count is not None: + src.Log.print_with_color(f"Label distribution of client: {self.label_count}", "yellow") + + # Load training dataset + if self.layer_id == 1 and data_name and not self.train_set and not self.label_to_indices: + if data_name == "MNIST": + transform_train = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + self.train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, + transform=transform_train) + + elif data_name == "CIFAR10": + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, + transform=transform_train) + else: + self.train_set = None + raise ValueError(f"Data name '{data_name}' is not valid.") + + self.label_to_indices = defaultdict(list) + for idx, (_, label) in tqdm(enumerate(self.train_set)): + self.label_to_indices[int(label)].append(idx) + + # Load model + if self.model is None: + + klass = globals()[f'{model_name}_{data_name}'] + + if cut_layers[1] == -1: + self.model = klass(start_layer=cut_layers[0]) + else: + self.model = klass(start_layer=cut_layers[0], end_layer=cut_layers[1]) + + self.model.to(self.device) + + batch_size = self.response["batch_size"] + lr = self.response["lr"] + momentum = self.response["momentum"] + out_cluster_id = self.response.get("out_cluster_id", -1) + + # Read parameters and load to model + if state_dict: + 
self.model.load_state_dict(state_dict) + + # Start training + if self.layer_id == 1: + selected_indices = [] + for label, count in enumerate(self.label_count): + selected_indices.extend(random.sample(self.label_to_indices[label], count)) + + subset = torch.utils.data.Subset(self.train_set, selected_indices) + train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + + result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, out_cluster_id=out_cluster_id) + + else: + # Layer 2 handles data asynchronously from Layer 1, no local_round limit + result, size = self.train_func(self.model, lr, momentum, out_cluster_id=out_cluster_id) + + # Stop training, then send parameters to server + model_state_dict = copy.deepcopy(self.model.state_dict()) + if self.device != "cpu": + for key in model_state_dict: + model_state_dict[key] = model_state_dict[key].to('cpu') + + data = {"action": "UPDATE", "client_id": self.client_id, "layer_id": self.layer_id, + "result": result, "size": size, + "message": "Sent parameters to Server", "parameters": model_state_dict} + src.Log.print_with_color("[>>>] Client sent parameters to server", "red") + self.send_to_server(data) + return True + elif action == "PAUSE": + return True + elif action == "STOP": + return False + + + def send_to_server(self, message): + self.response = None + + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + return self.response diff --git a/other/2LS/src/Scheduler.py b/other/2LS/src/Scheduler.py new file mode 100644 index 0000000..b7c2186 --- /dev/null +++ b/other/2LS/src/Scheduler.py @@ -0,0 +1,230 @@ +import time +import pickle + +import torch +import torch.optim as optim +import torch.nn as nn + +import src.Log +from tqdm import tqdm + +class Scheduler: + def __init__(self, client_id, layer_id, channel, device, in_cluster_id=-1, idx=-1): + 
self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.in_cluster_id = in_cluster_id + self.idx = idx + self.data_count = 0 + + def send_intermediate_output(self, output, labels, trace, out_cluster_id=-1): + + forward_queue_name = f'intermediate_queue_{out_cluster_id}_{self.in_cluster_id}_{self.idx}' + + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels.cpu().numpy(), + "trace": trace} + ) + else: + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels.cpu().numpy(), + "trace": [self.client_id]} + ) + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + + message = pickle.dumps( + {"data": gradient.detach().cpu().numpy(), "trace": trace}) + + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=1, out_cluster_id=-1): + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + + model.to(self.device) + + for i in range(local_round): + data_iter = iter(train_loader) + src.Log.print_with_color(f'Forward epoch {i+1}/{local_round}', 'green') + + with 
tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: + while True: + try: + training_data, labels = next(data_iter) + training_data = training_data.to(self.device) + intermediate_output = model(training_data) + intermediate_output = intermediate_output.detach().requires_grad_(True) + + self.data_count += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None, out_cluster_id=out_cluster_id) + while True: + model.train() + optimizer.zero_grad() + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + gradient_numpy = received_data["data"] + gradient = torch.tensor(gradient_numpy).to(self.device) + + output = model(training_data) + output.backward(gradient=gradient) + optimizer.step() + break + else: + # Check for PAUSE from server (Early Termination) + broadcast_queue_name = f'reply_{self.client_id}' + method_f, header_f, body_f = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body_f: + received_data = pickle.loads(body_f) + src.Log.print_with_color(f"[<<<] Received message during gradient wait: {received_data}", "blue") + if received_data.get("action") == "PAUSE": + return True + time.sleep(0.5) + continue + + pbar.update(1) + + except StopIteration: + break + + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "message": "Finish training!"} + + # Finish epoch training, send notify to server + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + broadcast_queue_name = f'reply_{self.client_id}' + while True: # Wait for broadcast + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + 
return True + time.sleep(0.5) + + def train_on_last_layer(self, model, lr, momentum, out_cluster_id=-1): + + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + result = True + + criterion = nn.CrossEntropyLoss() + + forward_queue_name = f'intermediate_queue_{out_cluster_id}_{self.in_cluster_id}_{self.idx}' + + self.channel.queue_declare(queue=forward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + print('Waiting for intermediate output. To exit press CTRL+C') + model.to(self.device) + infor_data = [] + tensor_data = torch.tensor([]) + tensor_label = torch.tensor([]) + count = 0 + + while True: + + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + labels_numpy = received_data["label"] + trace = received_data["trace"] + + labels = torch.tensor(labels_numpy) + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True) + + infor_data.append((trace, intermediate_output.size(0))) + tensor_data = torch.cat((tensor_data, intermediate_output), dim=0) + tensor_label = torch.cat((tensor_label, labels), dim=0) + + self.data_count += 1 + count += 1 + else: + # Check for PAUSE from server (Early Termination) + broadcast_queue_name = f'reply_{self.client_id}' + method_f, header_f, body_f = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body_f: + received_data = pickle.loads(body_f) + src.Log.print_with_color(f"[<<<] Received message during idle wait: {received_data}", "blue") + if received_data.get("action") == "PAUSE": + return True + + if count == 1: + model.train() + optimizer.zero_grad() + tensor_data = tensor_data.to(self.device) + tensor_data.retain_grad() + tensor_label = tensor_label.to(self.device) + + output = model(tensor_data) + loss = criterion(output, tensor_label.long()) + print(f"Loss: {loss.item()}") + + if 
torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", "yellow") + result = False + + loss.backward() + + optimizer.step() + gradient = tensor_data.grad + + for (trace, size) in infor_data: + grad, new_gradient = gradient.split([size, gradient.size(0) - size], dim=0) + gradient = new_gradient + + self.send_gradient(grad, trace) + + infor_data = [] + tensor_data = torch.tensor([]) + tensor_label = torch.tensor([]) + count = 0 + + else: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return result + + def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, out_cluster_id=-1): + self.data_count = 0 + if self.layer_id == 1: + result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, out_cluster_id=out_cluster_id) + else: + result = self.train_on_last_layer(model, lr, momentum, out_cluster_id=out_cluster_id) + + return result, self.data_count \ No newline at end of file diff --git a/other/2LS/src/Server.py b/other/2LS/src/Server.py new file mode 100644 index 0000000..20b371f --- /dev/null +++ b/other/2LS/src/Server.py @@ -0,0 +1,460 @@ +import os +import random +import pika +import pickle +import sys +import numpy as np +import copy +import src.Log +import src.Utils +import src.Validation + +from src.model import * + + +class Server: + def __init__(self, config): + # RabbitMQ + address = config["rabbit"]["address"] + username = config["rabbit"]["username"] + password = config["rabbit"]["password"] + virtual_host = config["rabbit"]["virtual-host"] + + self.partition = config["server"]["manual-cluster"] + + self.model_name = config["server"]["model"] + self.data_name = config["server"]["data-name"] + 
self.total_clients = config["server"]["clients"] + self.list_cut_layers = config["server"]["manual-cluster"]["cut-layers"] + self.global_round = config["server"]["global-round"] + self.local_round = config["server"]["local-round"] + self.round = self.global_round + self.validation = config["server"]["validation"] + + self.is_clustered = True + + # clustering + self.out_cluster_models = {} # {out_idx: state_dict} + self.out_cluster_order = [] # Shuffled of out-clusters + self.current_out_cluster_cursor = 0 # Pointer current out-cluster + self.current_out_cluster_idx = 0 + self.finished_clients_in_cluster = {} # {(out_idx, in_idx): count} + self.finished_upper_clients_count = {} # {out_idx: count} — Phase 2 only + + # FedAsync: track thứ tự in-cluster đến cho mỗi out-cluster + self.incluster_fedasync_order = {} # {out_idx: [in_idx_first, in_idx_second, ...]} + self.incluster_l1_avg = {} # {(out_idx, in_idx): state_dict} — saved L1 FedAvg result + self.incluster_l2_finished = {} # {(out_idx, in_idx): count} — L2 done per in-cluster + + # Clients + self.batch_size = config["learning"]["batch-size"] + self.lr = config["learning"]["learning-rate"] + self.momentum = config["learning"]["momentum"] + self.data_distribution = config["server"]["data-distribution"] + + # Data distribution + self.non_iid = self.data_distribution["non-iid"] + self.num_label = self.data_distribution["num-label"] + self.num_sample = self.data_distribution["num-sample"] + self.random_seed = config["server"]["random-seed"] + self.label_counts = None + self.label_ = None + + if self.random_seed: + random.seed(self.random_seed) + + log_path = config["log_path"] + + credentials = pika.PlainCredentials(username, password) + self.connection = pika.BlockingConnection( + pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) + self.channel = self.connection.channel() + self.channel.queue_declare(queue='rpc_queue') + + self.current_clients = 0 + self.register_clients = [0 for _ in 
range(len(self.total_clients))] + self.responses = {} # Save response + self.list_clients = [] + self.round_result = True + + # Model (Isolated buffers per cluster: (layer, out_idx, in_idx)) + self.global_model_parameters = {} + self.global_client_sizes = {} + + # Sequential + self.edge_device = [] + self.device_begin = [] + self.device_stop = [] + self.avg_state_dict = [] + self.upper_clients = {} # {out_cluster_id: [(cid, layer_id)]} for layer 2 clients + + self.channel.basic_qos(prefetch_count=1) + self.reply_channel = self.connection.channel() + self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request) + + debug_mode = config["debug_mode"] + self.logger = src.Log.Logger(f"{log_path}/app.log", debug_mode) + src.Log.print_with_color(f"Application start. Server is waiting for {self.total_clients} clients.", "green") + self.logger.log_info(f"Application start. Server is waiting for {self.total_clients} clients.") + + + def distribution(self): + if self.non_iid: + label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + ]) + + self.label_counts = (label_distribution * self.num_sample).astype(int) + self.label_ = copy.deepcopy(self.label_counts) + self.label_ = self.label_.tolist() + else: + self.label_counts = np.full((self.total_clients[0], self.num_label), self.num_sample // self.num_label) + self.label_ = copy.deepcopy(self.label_counts) + self.label_ = self.label_.tolist() + + def on_request(self, ch, method, props, body): + message = pickle.loads(body) + routing_key = props.reply_to + 
action = message["action"] + client_id = message["client_id"] + layer_id = int(message["layer_id"]) + + self.responses[routing_key] = message + ch.basic_ack(delivery_tag=method.delivery_tag) # Ack immediately — always + + if action == "REGISTER": + in_cluster_id = message["in_cluster_id"] + out_cluster_id = message["out_cluster_id"] + idx = message.get("idx", -1) + cid_str = str(client_id) + if (cid_str, layer_id, in_cluster_id, out_cluster_id, idx) not in self.list_clients: + self.list_clients.append((cid_str, layer_id, in_cluster_id, out_cluster_id, idx)) + + src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue") + self.register_clients[layer_id - 1] += 1 + + if self.register_clients == self.total_clients: + self.distribution() + + filepath = f'{self.model_name}_{self.data_name}.pth' + initial_sd = torch.load(filepath, weights_only=True) if os.path.exists(filepath) else {} + + # node[3] là out_cluster_id + unique_out_clusters = sorted(list(set(node[3] for node in self.list_clients))) # thứ tự 0 1 2 + for o_idx in unique_out_clusters: + self.out_cluster_models[o_idx] = copy.deepcopy(initial_sd) + + # Initialize Shuffled Out-cluster Order + self.out_cluster_order = unique_out_clusters + random.shuffle(self.out_cluster_order) + self.current_out_cluster_cursor = 0 + self.current_out_cluster_idx = self.out_cluster_order[0] + + src.Log.print_with_color(f"All clients connected. 
Shuffled Out-cluster order: {self.out_cluster_order}", + "green") + src.Log.print_with_color("Hierarchical structure initialized from predefined IDs.", "green") + self.logger.log_info(f"Start training round {self.global_round - self.round + 1}") + self.notify_clients(register=False) + + elif action == "NOTIFY": + src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue") + + message_pause = {"action": "PAUSE", + "message": "Pause training and please send your parameters", + "parameters": None} + + node = next((n for n in self.list_clients if n[0] == str(client_id)), None) + if int(layer_id) > 1: + out_idx = self.current_out_cluster_idx + in_idx = node[2] if node else 0 + elif node: + out_idx, in_idx = node[3], node[2] + else: + out_idx, in_idx = 0, 0 + + if out_idx == self.current_out_cluster_idx: + key = (out_idx, in_idx) + if key not in self.finished_clients_in_cluster: + self.finished_clients_in_cluster[key] = 0 + self.finished_clients_in_cluster[key] += 1 + + clients_in_cluster = [n[0] for n in self.list_clients if n[3] == out_idx and n[2] == in_idx and n[1] == 1] + if self.finished_clients_in_cluster[key] == len(clients_in_cluster): + src.Log.print_with_color(f">>> In-cluster ({out_idx}, {in_idx}) finished. 
Requesting parameters.", "yellow") + for cid in clients_in_cluster: + self.send_to_response(cid, pickle.dumps(message_pause)) + + elif action == "UPDATE": + data_message = message["message"] + result = message["result"] + model_state_dict = message["parameters"] + client_size = message["size"] + + src.Log.print_with_color(f"[<<<] Received message from {client_id}: {data_message}", "blue") + + node = next((n for n in self.list_clients if n[0] == str(client_id)), None) + if layer_id > 1: + out_idx = self.current_out_cluster_idx + in_idx = node[2] if node else 0 + #src.Log.print_with_color(f">>> Mapping Split Server {client_id} to active Out-cluster {out_idx}", "yellow") + elif node: + out_idx, in_idx = node[3], node[2] + else: + out_idx, in_idx = 0, 0 + + key = (out_idx, in_idx) + cid_str = str(client_id) + + if layer_id == 1: + if out_idx == self.current_out_cluster_idx: + cluster_key = (layer_id, out_idx, in_idx) + if cluster_key not in self.global_model_parameters: + self.global_model_parameters[cluster_key] = [] + self.global_client_sizes[cluster_key] = [] + + self.global_model_parameters[cluster_key].append(model_state_dict) + self.global_client_sizes[cluster_key].append(client_size) + + # số client ở layer 1 của cur_o_idx + clients_in_cluster = [n[0] for n in self.list_clients if int(n[3]) == out_idx and int(n[2]) == in_idx and int(n[1]) == 1] + total_in_cluster = len(clients_in_cluster) + + # FedAvg khi in-cluster đã nhận đủ update từ tất cả client + if len(self.global_model_parameters[cluster_key]) == total_in_cluster: + src.Log.print_with_color(f">>> Sync In-cluster L1 FedAvg ({out_idx}, {in_idx})", "yellow") + in_cluster_avg_sd = src.Utils.fedavg_state_dicts(self.global_model_parameters[cluster_key], + weights=self.global_client_sizes[cluster_key]) + self.global_model_parameters[cluster_key] = [] + self.global_client_sizes[cluster_key] = [] + self.finished_clients_in_cluster[(out_idx, in_idx)] = 0 + + # Track thứ tự in-cluster đến + if out_idx not in 
self.incluster_fedasync_order: + self.incluster_fedasync_order[out_idx] = [] + order_list = self.incluster_fedasync_order[out_idx] + order_list.append(in_idx) + + alpha = 1.0 if len(order_list) == 1 else 0.5 + src.Log.print_with_color(f">>> In-cluster ({out_idx}, {in_idx}) arrived {'FIRST' if alpha == 1.0 else 'LATER'}. alpha={alpha}", "green") + + # Lưu L1 FedAvg result — CHƯA FedAsync, chờ L2 + self.incluster_l1_avg[(out_idx, in_idx)] = in_cluster_avg_sd + + # Pause L1 clients + message_pause_l1 = {"action": "PAUSE", "message": "In-cluster done. Waiting.", "parameters": None} + for cid in clients_in_cluster: + self.send_to_response(cid, pickle.dumps(message_pause_l1)) + + # Gửi PAUSE cho L2 clients thuộc in-cluster này + l2_clients_this_ic = [n for n in self.list_clients if int(n[1]) > 1 and int(n[2]) == in_idx] + + if l2_clients_this_ic: + message_pause_l2 = {"action": "PAUSE", "message": f"Send L2 parameters for in-cluster {in_idx}.", "parameters": None} + seen = set() + for node in reversed(l2_clients_this_ic): + role_key = (node[1], node[2], node[3], node[4]) + if role_key not in seen: + seen.add(role_key) + src.Log.print_with_color(f">>> Sending PAUSE to L2 client {node[0]} for in-cluster {in_idx}", "yellow") + self.send_to_response(node[0], pickle.dumps(message_pause_l2)) + else: + # Không có L2 → FedAsync L1 trực tiếp (paper: model chỉ có L1) + src.Log.print_with_color(f">>> No L2 for in-cluster ({out_idx}, {in_idx}). FedAsync L1 only, alpha={alpha}", "green") + self.fedasync_aggregate(out_idx, in_cluster_avg_sd, alpha=alpha) + self.incluster_l1_avg.pop((out_idx, in_idx), None) + self.check_out_cluster_completion(out_idx) + + else: + message_pause = {"action": "PAUSE", "message": "Round mismatch. 
Waiting...", "parameters": None} + self.send_to_response(cid_str, pickle.dumps(message_pause)) + + elif layer_id > 1: + # Accumulate L2 update per in-cluster + src.Log.print_with_color(f">>> Received UPDATE from Upper Layer client {client_id} (Layer {layer_id}, in-cluster {in_idx})", "yellow") + l2_key = (layer_id, out_idx, in_idx) + if l2_key not in self.global_model_parameters: + self.global_model_parameters[l2_key] = [] + self.global_client_sizes[l2_key] = [] + self.global_model_parameters[l2_key].append(model_state_dict) + self.global_client_sizes[l2_key].append(client_size) + + # Đếm L2 per in-cluster + l2_ic_key = (out_idx, in_idx) + if l2_ic_key not in self.incluster_l2_finished: + self.incluster_l2_finished[l2_ic_key] = 0 + self.incluster_l2_finished[l2_ic_key] += 1 + + # Số L2 clients (unique roles) thuộc in-cluster này + total_upper_this_ic = len(set((n[1], n[2], n[3], n[4]) for n in self.list_clients if int(n[1]) > 1 and int(n[2]) == in_idx)) + + if self.incluster_l2_finished[l2_ic_key] >= total_upper_this_ic: + # FedAvg L2 + avg_sd_l2 = src.Utils.fedavg_state_dicts(self.global_model_parameters[l2_key], + weights=self.global_client_sizes[l2_key]) + self.global_model_parameters[l2_key] = [] + self.global_client_sizes[l2_key] = [] + + # Merge L1 avg + L2 avg → full model + l1_avg = self.incluster_l1_avg.pop(l2_ic_key, {}) + merged_sd = {} + merged_sd.update(l1_avg) + merged_sd.update(avg_sd_l2) + + # Lấy alpha theo thứ tự in-cluster đến + order_list = self.incluster_fedasync_order.get(out_idx, []) + arrival_pos = order_list.index(in_idx) if in_idx in order_list else 0 + alpha = 1.0 if arrival_pos == 0 else 0.5 + + src.Log.print_with_color(f">>> In-cluster ({out_idx}, {in_idx}) L1+L2 merged. 
FedAsync alpha={alpha} (paper Alg.1)", "green") + self.fedasync_aggregate(out_idx, merged_sd, alpha=alpha) + + self.check_out_cluster_completion(out_idx) + + def check_out_cluster_completion(self, out_idx): + # Kiểm tra tất cả in-cluster đã FedAsync xong + all_l1_in_oc = set(int(n[2]) for n in self.list_clients if int(n[1]) == 1 and int(n[3]) == out_idx) + order_list = self.incluster_fedasync_order.get(out_idx, []) + is_done = len(order_list) >= len(all_l1_in_oc) if all_l1_in_oc else True + + # Kiểm tra L2 per in-cluster: tất cả in-cluster phải hoàn thành L2 + all_l2_done = True + for ic_idx in all_l1_in_oc: + l2_count_for_ic = len(set((n[1], n[2], n[3], n[4]) for n in self.list_clients if int(n[1]) > 1 and int(n[2]) == ic_idx)) + if l2_count_for_ic > 0: + finished_l2_ic = self.incluster_l2_finished.get((out_idx, ic_idx), 0) + if finished_l2_ic < l2_count_for_ic: + all_l2_done = False + break + + if is_done and all_l2_done: + src.Log.print_with_color(f">>> Out-cluster {out_idx} FULLY completed (L1 & L2+).", "green") + + # Validation: tổng hợp mô hình và in accuracy + state_dict_full = self.out_cluster_models[out_idx] + if len(state_dict_full) > 0: + src.Log.print_with_color(f">>> Running validation for Out-cluster {out_idx}...", "yellow") + src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger) + + # Reset cho out-cluster này + self.incluster_fedasync_order[out_idx] = [] + self.finished_upper_clients_count[out_idx] = 0 + # Reset L2 per in-cluster counters + for ic_idx in all_l1_in_oc: + self.incluster_l2_finished.pop((out_idx, ic_idx), None) + self.incluster_l1_avg.pop((out_idx, ic_idx), None) + + # chuyển sang outcluster tiếp + self.current_out_cluster_cursor += 1 + + if self.current_out_cluster_cursor >= len(self.out_cluster_order): + # xong 1 round ( chạy qua hết outcluster) + self.round -= 1 + if self.round <= 0: + src.Log.print_with_color(">>> All global rounds completed.", "green") + state_dict_full = 
self.out_cluster_models[out_idx] + torch.save(state_dict_full, f'{self.model_name}_{self.data_name}.pth') + src.Log.print_with_color(">>> Server training process total completion.", "green") + return + + # next round + self.current_out_cluster_cursor = 0 + random.shuffle(self.out_cluster_order) + src.Log.print_with_color( + f">>> New Global Round. Shuffled Out-cluster order: {self.out_cluster_order}", "green") + + # Set next out-cluster + next_out_idx = self.out_cluster_order[self.current_out_cluster_cursor] + self.out_cluster_models[next_out_idx] = copy.deepcopy(self.out_cluster_models[out_idx]) + self.current_out_cluster_idx = next_out_idx + + # Start next Out-cluster + src.Log.print_with_color(f">>> Moving to Out-cluster {self.current_out_cluster_idx}", "yellow") + self.notify_clients(register=False) + + # FedAsync: W_new = (1-alpha)*W_old + alpha*W_received + def fedasync_aggregate(self, out_idx, in_cluster_sd, alpha=1.0): + target_sd = self.out_cluster_models[out_idx] + for key in in_cluster_sd.keys(): + if key in target_sd: + target_sd[key] = (1.0 - alpha) * target_sd[key].float() + alpha * in_cluster_sd[key].float() + target_sd[key] = target_sd[key].to(in_cluster_sd[key].dtype) + else: + # Key chưa tồn tại (model khởi tạo rỗng) → thêm trực tiếp + target_sd[key] = in_cluster_sd[key].clone() + src.Log.print_with_color(f">>> FedAsync Out-cluster {out_idx} updated (alpha={alpha}).", "green") + + def notify_clients(self, start=True, register=True): + if not start: + for node in self.list_clients: + self.send_to_response(node[0], pickle.dumps({"action": "STOP", "message": "Stop!", "parameters": None})) + return + + if register: + pass + else: + o_idx = self.current_out_cluster_idx + full_sd = self.out_cluster_models[o_idx] + cut = int(self.list_cut_layers[0][0] if isinstance(self.list_cut_layers[0], list) else self.list_cut_layers[0]) + klass = globals()[f'{self.model_name}_{self.data_name}'] + + src.Log.print_with_color(f">>> Starting Training Round for 
Out-cluster {o_idx}...", "red") + + # Notify cho những thằng ở o_idx hiện tại thoi + for role in set(n[1:] for n in self.list_clients): + layer_id, in_idx, out_idx, idx = role + if not ((layer_id == 1 and out_idx == o_idx) or layer_id > 1): + continue + + client_id = next(n[0] for n in reversed(self.list_clients) if n[1:] == role) + layers = [0, cut] if layer_id == 1 else [cut, -1] + model = klass(end_layer=cut) if layer_id == 1 else klass(start_layer=cut) + + state_dict = model.state_dict() + if len(full_sd) > 0: + for key in state_dict.keys(): + if key in full_sd: + state_dict[key] = full_sd[key] + + label = self.label_[idx] if (layer_id == 1 and self.label_) else [] + + response = {"action": "START", "message": "Training Start", "parameters": state_dict, + "layers": layers, "model_name": self.model_name, "data_name": self.data_name, + "batch_size": self.batch_size, "lr": self.lr, "momentum": self.momentum, + "label_count": label, "local_round": self.local_round, "cluster": in_idx, + "out_cluster_id": o_idx} + + src.Log.print_with_color(f">>> Notifying client {client_id} (layer {layer_id}) for out-cluster {o_idx}, in-cluster {in_idx}", "yellow") + self.send_to_response(client_id, pickle.dumps(response)) + + def start(self): + self.channel.start_consuming() + + def send_to_response(self, client_id, message): + reply_queue_name = f'reply_{client_id}' + self.reply_channel.queue_declare(reply_queue_name, durable=False) + src.Log.print_with_color(f"[>>>] Sent notification to client {client_id}", "red") + self.reply_channel.basic_publish(exchange='', routing_key=reply_queue_name, body=message) + + def avg_all_parameters(self): + layer_sizes = self.global_client_sizes + layer_params = self.global_model_parameters + for layer_idx, list_state_dicts in enumerate(layer_params): + list_sizes = layer_sizes[layer_idx] + if not list_state_dicts or not list_sizes: + self.avg_state_dict.append({}) + continue + avg_sd = src.Utils.fedavg_state_dicts(list_state_dicts, 
def fedavg_state_dicts(state_dicts, weights=None):
    """Federated averaging (FedAvg) over a list of model state_dicts.

    Args:
        state_dicts: list of dicts ``{param_name: tensor}``. Keys may differ
            between entries (e.g. partial split-learning models); each key is
            averaged over only the state_dicts that contain it.
        weights: optional list of per-client weights (e.g. local dataset
            sizes), one per state_dict. ``None`` means equal weighting.

    Returns:
        dict ``{param_name: averaged_tensor}``, cast back to each parameter's
        original dtype (integer/bool tensors are rounded before casting).

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """
    num = len(state_dicts)
    if num == 0:
        raise ValueError("fedavg_state_dicts: không có state_dict nào để trung bình.")

    if weights is None:
        weights = [1.0] * num

    # Union of all keys so partial (split) models can still be merged.
    all_keys = set().union(*(sd.keys() for sd in state_dicts))
    avg_dict = {}

    for key in all_keys:
        acc = None
        # BUGFIX: normalize by the summed weight of the clients that actually
        # hold this key, not by the global total — otherwise parameters that
        # are missing from some clients get systematically scaled down.
        key_weight = 0.0
        for sd, w in zip(state_dicts, weights):
            if key not in sd:
                continue
            t = sd[key].float()
            if torch.isnan(t).any():
                t = torch.nan_to_num(t)  # zero-fill NaNs from bad updates
            key_weight += w
            t = t * w
            acc = t if acc is None else acc + t

        avg = acc / key_weight

        # Cast back to the original dtype of this parameter.
        orig = next(sd[key] for sd in state_dicts if key in sd)
        if orig.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.bool):
            avg = avg.round().to(orig.dtype)
        else:
            avg = avg.to(orig.dtype)

        avg_dict[key] = avg

    return avg_dict
0.2010)), + ]) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + else: + raise ValueError(f"Data name '{data_name}' is not valid.") + + test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) + + + klass = globals()[f'{model_name}_{data_name}'] + + if klass is None: + raise ValueError(f"Class '{model_name}' does not exist.") + + model = klass() + + model.load_state_dict(state_dict_full) + model = model.to(device) + # evaluation mode + model.eval() + test_loss = 0 + correct = 0 + for data, target in tqdm(test_loader): + data = data.to(device) + target = target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() + + test_loss /= len(test_loader.dataset) + accuracy = 100.0 * correct / len(test_loader.dataset) + print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), accuracy)) + + if np.isnan(test_loss) or math.isnan(test_loss) or abs(test_loss) > 10e5: + return False + else: + logger.log_info('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), accuracy)) + + return True diff --git a/other/2LS/src/model/BERT_EMOTION.py b/other/2LS/src/model/BERT_EMOTION.py new file mode 100644 index 0000000..efb4ab9 --- /dev/null +++ b/other/2LS/src/model/BERT_EMOTION.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Configuration constants +NUM_SAMPLES = 5000 +num_labels = 6 +vocab_size = 30522 +hidden_size = 768 +num_hidden_layers = 12 +num_attention_heads = 12 +intermediate_size = 3072 +max_position_embeddings = 512 +type_vocab_size = 2 +dropout_prob = 0.1 + +# BertEmbeddings class +class 
class BertSdpaSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (BERT style).

    Projects the input into per-head query/key/value spaces, computes
    ``softmax(Q K^T / sqrt(d_head)) V``, and merges the heads back into a
    single hidden vector per position.

    Args:
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of parallel attention heads.
        dropout_prob: dropout applied to the attention probabilities.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob):
        super(BertSdpaSelfAttention, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, hidden) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: tensor of shape (batch, seq, hidden_size).
            attention_mask: optional (batch, seq) mask of 1s (keep) / 0s (pad).

        Returns:
            Tensor of shape (batch, seq, hidden_size).
        """
        # Project and split into heads: (batch, heads, seq, head_size).
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention scores: (batch, heads, seq, seq).
        # NOTE: `x ** 0.5` replaces the original per-call `import math` +
        # `math.sqrt(...)` — numerically identical, no import on the hot path.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        if attention_mask is not None:
            # Broadcast (batch, seq) -> (batch, 1, 1, seq) and turn padding
            # positions (0s) into large negative additive biases.
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
            attention_scores = attention_scores + extended_attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge heads back to hidden_size.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)
class BertAttention(nn.Module):
    """Complete attention sublayer: multi-head self-attention followed by
    the residual projection block (dense -> dropout -> add & LayerNorm)."""

    def __init__(self, hidden_size, num_attention_heads, dropout_prob):
        super(BertAttention, self).__init__()
        self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob)
        self.output = BertSelfOutput(hidden_size, dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        # Attend, then project with a residual connection back to the input.
        return self.output(self.self(hidden_states, attention_mask), hidden_states)


class BertIntermediate(nn.Module):
    """Position-wise feed-forward expansion: dense -> GELU."""

    def __init__(self, hidden_size, intermediate_size):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))


class BertOutput(nn.Module):
    """Feed-forward contraction: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, hidden_size, intermediate_size, dropout_prob):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +# BertClassifier class +class BertClassifier(nn.Module): + def __init__(self, hidden_size, num_labels, dropout_prob=0.1): + super(BertClassifier, self).__init__() + self.dropout = nn.Dropout(dropout_prob) + self.classifier = nn.Linear(hidden_size, num_labels) + + def forward(self, pooled_output): + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + +# Complete BERT Model +class BERT_EMOTION(nn.Module): + def __init__(self,start_layer= 0, end_layer= 27, vocab_size=30522, hidden_size=768, intermediate_size=3072, + num_attention_heads=12, num_labels=4, max_position_embeddings=512, + type_vocab_size=2, dropout_prob=0.1, num_hidden_layers=12): + + super(BERT_EMOTION, self).__init__() + + self.start_layer = start_layer + self.end_layer = end_layer + + if (self.start_layer < 1) and (self.end_layer >= 1): + self.layer1 = BertEmbeddings(vocab_size, hidden_size , max_position_embeddings, type_vocab_size, dropout_prob) + + if (self.start_layer < 2) and (self.end_layer >= 2): + self.layer2 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 3) and (self.end_layer >= 3): + self.layer3 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 4) and (self.end_layer >= 4): + self.layer4 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 5) and (self.end_layer >= 5): + self.layer5 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 6) and 
(self.end_layer >= 6): + self.layer6 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 7) and (self.end_layer >= 7): + self.layer7 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 8) and (self.end_layer >= 8): + self.layer8 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 9) and (self.end_layer >= 9): + self.layer9 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 10) and (self.end_layer >= 10): + self.layer10 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 11) and (self.end_layer >= 11): + self.layer11 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 12) and (self.end_layer >= 12): + self.layer12 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 13) and (self.end_layer >= 13): + self.layer13 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 14) and (self.end_layer >= 14): + self.layer14 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 15) and (self.end_layer >= 15): + self.layer15 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), 
+ BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 16) and (self.end_layer >= 16): + self.layer16 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 17) and (self.end_layer >= 17): + self.layer17 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 18) and (self.end_layer >= 18): + self.layer18 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 19) and (self.end_layer >= 19): + self.layer19 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 20) and (self.end_layer >= 20): + self.layer20 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 21) and (self.end_layer >= 21): + self.layer21 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 22) and (self.end_layer >= 22): + self.layer22 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 23) and (self.end_layer >= 23): + self.layer23 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 24) and (self.end_layer >= 24): + self.layer24 = nn.ModuleList([ + BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob), + BertSelfOutput(hidden_size, dropout_prob) + ]) + + if (self.start_layer < 25) 
and (self.end_layer >= 25): + self.layer25 = nn.ModuleList([ + BertIntermediate(hidden_size, intermediate_size), + BertOutput(hidden_size, intermediate_size, dropout_prob) + ]) + + if (self.start_layer < 26) and (self.end_layer >= 26): + self.layer26 = BertPooler(hidden_size) + + if (self.start_layer < 27) and (self.end_layer >= 27): + self.layer27 = BertClassifier(hidden_size, num_labels, dropout_prob) + + def forward(self, x, attention_mask=None, token_type_ids=None): + if (self.start_layer < 1) and (self.end_layer >= 1): + x = self.layer1(x, token_type_ids) + + if (self.start_layer < 2) and (self.end_layer >= 2): + x = self.layer2[1](self.layer2[0](x, attention_mask), x) + + if (self.start_layer < 3) and (self.end_layer >= 3): + x = self.layer3[1](self.layer3[0](x), x) + + if (self.start_layer < 4) and (self.end_layer >= 4): + x = self.layer4[1](self.layer4[0](x, attention_mask), x) + + if (self.start_layer < 5) and (self.end_layer >= 5): + x = self.layer5[1](self.layer5[0](x), x) + + if (self.start_layer < 6) and (self.end_layer >= 6): + x = self.layer6[1](self.layer6[0](x, attention_mask), x) + + if (self.start_layer < 7) and (self.end_layer >= 7): + x = self.layer7[1](self.layer7[0](x), x) + + if (self.start_layer < 8) and (self.end_layer >= 8): + x = self.layer8[1](self.layer8[0](x, attention_mask), x) + + if (self.start_layer < 9) and (self.end_layer >= 9): + x = self.layer9[1](self.layer9[0](x), x) + + if (self.start_layer < 10) and (self.end_layer >= 10): + x = self.layer10[1](self.layer10[0](x, attention_mask), x) + + if (self.start_layer < 11) and (self.end_layer >= 11): + x = self.layer11[1](self.layer11[0](x), x) + + if (self.start_layer < 12) and (self.end_layer >= 12): + x = self.layer12[1](self.layer12[0](x, attention_mask), x) + + if (self.start_layer < 13) and (self.end_layer >= 13): + x = self.layer13[1](self.layer13[0](x), x) + + if (self.start_layer < 14) and (self.end_layer >= 14): + x = self.layer14[1](self.layer14[0](x, attention_mask), x) 
class MobileNetv1_CIFAR10(nn.Module):
    """MobileNetV1-style CNN for CIFAR-10, sliced for split learning.

    Layers are numbered 1..84; only layers ``i`` with
    ``start_layer < i <= end_layer`` are instantiated, so a client can hold
    an arbitrary contiguous slice of the network. Attribute names
    (``layer1`` .. ``layer84``) and module creation order match the original
    hand-unrolled definition exactly, keeping state_dicts compatible.

    Args:
        start_layer: exclusive lower bound of the slice (0 = from the start).
        end_layer: inclusive upper bound of the slice (84 = to the end).
    """

    NUM_LAYERS = 84

    def __init__(self, start_layer=0, end_layer=84):
        super(MobileNetv1_CIFAR10, self).__init__()
        self.start_layer = start_layer
        self.end_layer = end_layer

        # (in_ch, out_ch, kernel, stride, padding) for each conv. Every conv
        # is followed by BatchNorm2d(out_ch) and ReLU — 27 triples = layers
        # 1..81, then the classifier head occupies layers 82..84.
        convs = [
            (3, 32, 3, 1, 1),
            (32, 32, 3, 1, 1), (32, 64, 1, 1, 0),
            (64, 64, 3, 2, 1), (64, 128, 1, 1, 0),
            (128, 128, 3, 1, 1), (128, 128, 1, 1, 0),
            (128, 128, 3, 2, 1), (128, 256, 1, 1, 0),
            (256, 256, 3, 1, 1), (256, 256, 1, 1, 0),
            (256, 256, 3, 2, 1), (256, 512, 1, 1, 0),
            (512, 512, 3, 1, 1), (512, 512, 1, 1, 0),
            (512, 512, 3, 1, 1), (512, 512, 1, 1, 0),
            (512, 512, 3, 1, 1), (512, 512, 1, 1, 0),
            (512, 512, 3, 1, 1), (512, 512, 1, 1, 0),
            (512, 512, 3, 1, 1), (512, 512, 1, 1, 0),
            (512, 512, 3, 2, 1), (512, 1024, 1, 1, 0),
            (1024, 1024, 3, 1, 1), (1024, 1024, 1, 1, 0),
        ]

        # Build one constructor per layer index (lazy, so that only the
        # layers inside the requested slice are ever instantiated).
        factories = {}
        idx = 1
        for in_ch, out_ch, k, s, p in convs:
            # Bind loop variables as defaults to avoid late-binding closures.
            factories[idx] = (lambda i=in_ch, o=out_ch, kk=k, ss=s, pp=p:
                              nn.Conv2d(i, o, kk, ss, pp))
            factories[idx + 1] = (lambda o=out_ch: nn.BatchNorm2d(o))
            factories[idx + 2] = nn.ReLU
            idx += 3
        factories[82] = lambda: nn.MaxPool2d(2, 2)
        factories[83] = lambda: nn.Flatten(1, -1)
        factories[84] = lambda: nn.Linear(1024, 10)

        # Instantiate in ascending index order — same attribute names and
        # same RNG consumption order as the original unrolled version.
        for i in range(1, self.NUM_LAYERS + 1):
            if start_layer < i <= end_layer:
                setattr(self, f'layer{i}', factories[i]())

    def forward(self, x):
        """Apply only the layers owned by this model slice, in order."""
        for i in range(1, self.NUM_LAYERS + 1):
            if self.start_layer < i <= self.end_layer:
                layer = getattr(self, f'layer{i}', None)
                if layer is not None:
                    x = layer(x)
        return x
nn.Conv2d(128, 128, kernel_size=1, stride=1) + if start_layer < 20 <= end_layer: + self.layer20 = nn.BatchNorm2d(128) + if start_layer < 21 <= end_layer: + self.layer21 = nn.ReLU() + if start_layer < 22 <= end_layer: + self.layer22 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1) + if start_layer < 23 <= end_layer: + self.layer23 = nn.BatchNorm2d(128) + if start_layer < 24 <= end_layer: + self.layer24 = nn.ReLU() + if start_layer < 25 <= end_layer: + self.layer25 = nn.Conv2d(128, 256, kernel_size=1, stride=1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(256) + if start_layer < 27 <= end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(256) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(256, 256, kernel_size=1, stride=1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(256) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1) + if start_layer < 35 <= end_layer: + self.layer35 = nn.BatchNorm2d(256) + if start_layer < 36 <= end_layer: + self.layer36 = nn.ReLU() + if start_layer < 37 <= end_layer: + self.layer37 = nn.Conv2d(256, 512, kernel_size=1, stride=1) + if start_layer < 38 <= end_layer: + self.layer38 = nn.BatchNorm2d(512) + if start_layer < 39 <= end_layer: + self.layer39 = nn.ReLU() + if start_layer < 40 <= end_layer: + self.layer40 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 41 <= end_layer: + self.layer41 = nn.BatchNorm2d(512) + if start_layer < 42 <= end_layer: + self.layer42 = nn.ReLU() + if start_layer < 43 <= end_layer: + self.layer43 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 44 <= end_layer: + 
self.layer44 = nn.BatchNorm2d(512) + if start_layer < 45 <= end_layer: + self.layer45 = nn.ReLU() + if start_layer < 46 <= end_layer: + self.layer46 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 47 <= end_layer: + self.layer47 = nn.BatchNorm2d(512) + if start_layer < 48 <= end_layer: + self.layer48 = nn.ReLU() + if start_layer < 49 <= end_layer: + self.layer49 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 50 <= end_layer: + self.layer50 = nn.BatchNorm2d(512) + if start_layer < 51 <= end_layer: + self.layer51 = nn.ReLU() + if start_layer < 52 <= end_layer: + self.layer52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 53 <= end_layer: + self.layer53 = nn.BatchNorm2d(512) + if start_layer < 54 <= end_layer: + self.layer54 = nn.ReLU() + if start_layer < 55 <= end_layer: + self.layer55 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 56 <= end_layer: + self.layer56 = nn.BatchNorm2d(512) + if start_layer < 57 <= end_layer: + self.layer57 = nn.ReLU() + if start_layer < 58 <= end_layer: + self.layer58 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 59 <= end_layer: + self.layer59 = nn.BatchNorm2d(512) + if start_layer < 60 <= end_layer: + self.layer60 = nn.ReLU() + if start_layer < 61 <= end_layer: + self.layer61 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 62 <= end_layer: + self.layer62 = nn.BatchNorm2d(512) + if start_layer < 63 <= end_layer: + self.layer63 = nn.ReLU() + if start_layer < 64 <= end_layer: + self.layer64 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 65 <= end_layer: + self.layer65 = nn.BatchNorm2d(512) + if start_layer < 66 <= end_layer: + self.layer66 = nn.ReLU() + if start_layer < 67 <= end_layer: + self.layer67 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 68 <= end_layer: + self.layer68 = nn.BatchNorm2d(512) + if start_layer < 69 <= end_layer: + 
self.layer69 = nn.ReLU() + if start_layer < 70 <= end_layer: + self.layer70 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1) + if start_layer < 71 <= end_layer: + self.layer71 = nn.BatchNorm2d(512) + if start_layer < 72 <= end_layer: + self.layer72 = nn.ReLU() + if start_layer < 73 <= end_layer: + self.layer73 = nn.Conv2d(512, 1024, kernel_size=1, stride=1) + if start_layer < 74 <= end_layer: + self.layer74 = nn.BatchNorm2d(1024) + if start_layer < 75 <= end_layer: + self.layer75 = nn.ReLU() + if start_layer < 76 <= end_layer: + self.layer76 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1) + if start_layer < 77 <= end_layer: + self.layer77 = nn.BatchNorm2d(1024) + if start_layer < 78 <= end_layer: + self.layer78 = nn.ReLU() + if start_layer < 79 <= end_layer: + self.layer79 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1) + if start_layer < 80 <= end_layer: + self.layer80 = nn.BatchNorm2d(1024) + if start_layer < 81 <= end_layer: + self.layer81 = nn.ReLU() + if start_layer < 82 <= end_layer: + self.layer82 = nn.MaxPool2d(2, 2) + if start_layer < 83 <= end_layer: + self.layer83 = nn.Flatten(1, -1) + if start_layer < 84 <= end_layer: + self.layer84 = nn.Linear(1024, 10) + + def forward(self, x): + for i in range(1, 85): + if self.start_layer < i <= self.end_layer: + layer = getattr(self, f'layer{i}', None) + if layer is not None: + x = layer(x) + return x diff --git a/other/2LS/src/model/VGG16_CIFAR10.py b/other/2LS/src/model/VGG16_CIFAR10.py new file mode 100644 index 0000000..e84652c --- /dev/null +++ b/other/2LS/src/model/VGG16_CIFAR10.py @@ -0,0 +1,230 @@ +import torch.nn as nn + +class VGG16_CIFAR10(nn.Module): + def __init__(self, start_layer=0, end_layer=52): + super(VGG16_CIFAR10, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + if start_layer < 1 <= end_layer: + self.layer1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 2 <= end_layer: + self.layer2 = nn.BatchNorm2d(64) + if 
start_layer < 3 <= end_layer: + self.layer3 = nn.ReLU() + if start_layer < 4 <= end_layer: + self.layer4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 5 <= end_layer: + self.layer5 = nn.BatchNorm2d(64) + if start_layer < 6 <= end_layer: + self.layer6 = nn.ReLU() + if start_layer < 7 <= end_layer: + self.layer7 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 8 <= end_layer: + self.layer8 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 9 <= end_layer: + self.layer9 = nn.BatchNorm2d(128) + if start_layer < 10 <= end_layer: + self.layer10 = nn.ReLU() + if start_layer < 11 <= end_layer: + self.layer11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 12 <= end_layer: + self.layer12 = nn.BatchNorm2d(128) + if start_layer < 13 <= end_layer: + self.layer13 = nn.ReLU() + if start_layer < 14 <= end_layer: + self.layer14 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 15 <= end_layer: + self.layer15 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 16 <= end_layer: + self.layer16 = nn.BatchNorm2d(256) + if start_layer < 17 <= end_layer: + self.layer17 = nn.ReLU() + if start_layer < 18 <= end_layer: + self.layer18 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 19 <= end_layer: + self.layer19 = nn.BatchNorm2d(256) + if start_layer < 20 <= end_layer: + self.layer20 = nn.ReLU() + if start_layer < 21 <= end_layer: + self.layer21 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 22 <= end_layer: + self.layer22 = nn.BatchNorm2d(256) + if start_layer < 23 <= end_layer: + self.layer23 = nn.ReLU() + if start_layer < 24 <= end_layer: + self.layer24 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 25 <= end_layer: + self.layer25 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(512) + if start_layer < 27 <= 
end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(512) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(512) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 35 <= end_layer: + self.layer35 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 36 <= end_layer: + self.layer36 = nn.BatchNorm2d(512) + if start_layer < 37 <= end_layer: + self.layer37 = nn.ReLU() + if start_layer < 38 <= end_layer: + self.layer38 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 39 <= end_layer: + self.layer39 = nn.BatchNorm2d(512) + if start_layer < 40 <= end_layer: + self.layer40 = nn.ReLU() + if start_layer < 41 <= end_layer: + self.layer41 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 42 <= end_layer: + self.layer42 = nn.BatchNorm2d(512) + if start_layer < 43 <= end_layer: + self.layer43 = nn.ReLU() + if start_layer < 44 <= end_layer: + self.layer44 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 45 <= end_layer: + self.layer45 = nn.Flatten(1, -1) + if start_layer < 46 <= end_layer: + self.layer46 = nn.Dropout(0.5) + if start_layer < 47 <= end_layer: + self.layer47 = nn.Linear(512 * 1 * 1, 4096) + if start_layer < 48 <= end_layer: + self.layer48 = nn.ReLU() + if start_layer < 49 <= end_layer: + self.layer49 = nn.Dropout(0.5) + if start_layer < 50 <= end_layer: + self.layer50 = nn.Linear(4096, 4096) + if start_layer < 51 <= end_layer: + self.layer51 = nn.ReLU() + if start_layer < 52 <= end_layer: + self.layer52 = 
nn.Linear(4096, 10) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + if self.start_layer < 3 <= self.end_layer: + x = self.layer3(x) + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x) + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x) + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + if self.start_layer < 13 <= self.end_layer: + x = self.layer13(x) + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x) + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + if self.start_layer < 18 <= self.end_layer: + x = self.layer18(x) + if self.start_layer < 19 <= self.end_layer: + x = self.layer19(x) + if self.start_layer < 20 <= self.end_layer: + x = self.layer20(x) + if self.start_layer < 21 <= self.end_layer: + x = self.layer21(x) + if self.start_layer < 22 <= self.end_layer: + x = self.layer22(x) + if self.start_layer < 23 <= self.end_layer: + x = self.layer23(x) + if self.start_layer < 24 <= self.end_layer: + x = self.layer24(x) + + if self.start_layer < 25 <= self.end_layer: + x = self.layer25(x) + if self.start_layer < 26 <= self.end_layer: + x = self.layer26(x) + if self.start_layer < 27 <= self.end_layer: + x = self.layer27(x) + if self.start_layer < 28 <= self.end_layer: + x = self.layer28(x) + if self.start_layer < 29 <= self.end_layer: + x = 
self.layer29(x) + if self.start_layer < 30 <= self.end_layer: + x = self.layer30(x) + if self.start_layer < 31 <= self.end_layer: + x = self.layer31(x) + if self.start_layer < 32 <= self.end_layer: + x = self.layer32(x) + if self.start_layer < 33 <= self.end_layer: + x = self.layer33(x) + if self.start_layer < 34 <= self.end_layer: + x = self.layer34(x) + + if self.start_layer < 35 <= self.end_layer: + x = self.layer35(x) + if self.start_layer < 36 <= self.end_layer: + x = self.layer36(x) + if self.start_layer < 37 <= self.end_layer: + x = self.layer37(x) + if self.start_layer < 38 <= self.end_layer: + x = self.layer38(x) + if self.start_layer < 39 <= self.end_layer: + x = self.layer39(x) + if self.start_layer < 40 <= self.end_layer: + x = self.layer40(x) + if self.start_layer < 41 <= self.end_layer: + x = self.layer41(x) + if self.start_layer < 42 <= self.end_layer: + x = self.layer42(x) + if self.start_layer < 43 <= self.end_layer: + x = self.layer43(x) + if self.start_layer < 44 <= self.end_layer: + x = self.layer44(x) + + if self.start_layer < 45 <= self.end_layer: + x = self.layer45(x) + if self.start_layer < 46 <= self.end_layer: + x = self.layer46(x) + if self.start_layer < 47 <= self.end_layer: + x = self.layer47(x) + if self.start_layer < 48 <= self.end_layer: + x = self.layer48(x) + if self.start_layer < 49 <= self.end_layer: + x = self.layer49(x) + if self.start_layer < 50 <= self.end_layer: + x = self.layer50(x) + if self.start_layer < 51 <= self.end_layer: + x = self.layer51(x) + if self.start_layer < 52 <= self.end_layer: + x = self.layer52(x) + + return x diff --git a/other/2LS/src/model/VGG16_MNIST.py b/other/2LS/src/model/VGG16_MNIST.py new file mode 100644 index 0000000..1cc070e --- /dev/null +++ b/other/2LS/src/model/VGG16_MNIST.py @@ -0,0 +1,226 @@ +import torch.nn as nn + +class VGG16_MNIST(nn.Module): + def __init__(self, start_layer=0, end_layer=51): + super(VGG16_MNIST, self).__init__() + self.start_layer = start_layer + self.end_layer = 
end_layer + + if start_layer < 1 <= end_layer: + self.layer1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 2 <= end_layer: + self.layer2 = nn.BatchNorm2d(64) + if start_layer < 3 <= end_layer: + self.layer3 = nn.ReLU() + if start_layer < 4 <= end_layer: + self.layer4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 5 <= end_layer: + self.layer5 = nn.BatchNorm2d(64) + if start_layer < 6 <= end_layer: + self.layer6 = nn.ReLU() + if start_layer < 7 <= end_layer: + self.layer7 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 8 <= end_layer: + self.layer8 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 9 <= end_layer: + self.layer9 = nn.BatchNorm2d(128) + if start_layer < 10 <= end_layer: + self.layer10 = nn.ReLU() + if start_layer < 11 <= end_layer: + self.layer11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 12 <= end_layer: + self.layer12 = nn.BatchNorm2d(128) + if start_layer < 13 <= end_layer: + self.layer13 = nn.ReLU() + if start_layer < 14 <= end_layer: + self.layer14 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 15 <= end_layer: + self.layer15 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 16 <= end_layer: + self.layer16 = nn.BatchNorm2d(256) + if start_layer < 17 <= end_layer: + self.layer17 = nn.ReLU() + if start_layer < 18 <= end_layer: + self.layer18 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 19 <= end_layer: + self.layer19 = nn.BatchNorm2d(256) + if start_layer < 20 <= end_layer: + self.layer20 = nn.ReLU() + if start_layer < 21 <= end_layer: + self.layer21 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 22 <= end_layer: + self.layer22 = nn.BatchNorm2d(256) + if start_layer < 23 <= end_layer: + self.layer23 = nn.ReLU() + if start_layer < 24 <= end_layer: + self.layer24 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 
25 <= end_layer: + self.layer25 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(512) + if start_layer < 27 <= end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(512) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(512) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 35 <= end_layer: + self.layer35 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 36 <= end_layer: + self.layer36 = nn.BatchNorm2d(512) + if start_layer < 37 <= end_layer: + self.layer37 = nn.ReLU() + if start_layer < 38 <= end_layer: + self.layer38 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 39 <= end_layer: + self.layer39 = nn.BatchNorm2d(512) + if start_layer < 40 <= end_layer: + self.layer40 = nn.ReLU() + if start_layer < 41 <= end_layer: + self.layer41 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 42 <= end_layer: + self.layer42 = nn.BatchNorm2d(512) + if start_layer < 43 <= end_layer: + self.layer43 = nn.ReLU() + + if start_layer < 44 <= end_layer: + self.layer44 = nn.Flatten(1, -1) + if start_layer < 45 <= end_layer: + self.layer45 = nn.Dropout(0.5) + if start_layer < 46 <= end_layer: + self.layer46 = nn.Linear(512, 4096) + if start_layer < 47 <= end_layer: + self.layer47 = nn.ReLU() + if start_layer < 48 <= end_layer: + self.layer48 = nn.Dropout(0.5) + if start_layer < 49 <= end_layer: + self.layer49 = nn.Linear(4096, 4096) + if start_layer < 50 <= 
end_layer: + self.layer50 = nn.ReLU() + if start_layer < 51 <= end_layer: + self.layer51 = nn.Linear(4096, 10) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + if self.start_layer < 3 <= self.end_layer: + x = self.layer3(x) + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x) + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x) + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + if self.start_layer < 13 <= self.end_layer: + x = self.layer13(x) + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x) + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + if self.start_layer < 18 <= self.end_layer: + x = self.layer18(x) + if self.start_layer < 19 <= self.end_layer: + x = self.layer19(x) + if self.start_layer < 20 <= self.end_layer: + x = self.layer20(x) + if self.start_layer < 21 <= self.end_layer: + x = self.layer21(x) + if self.start_layer < 22 <= self.end_layer: + x = self.layer22(x) + if self.start_layer < 23 <= self.end_layer: + x = self.layer23(x) + if self.start_layer < 24 <= self.end_layer: + x = self.layer24(x) + + if self.start_layer < 25 <= self.end_layer: + x = self.layer25(x) + if self.start_layer < 26 <= self.end_layer: + x = self.layer26(x) + if self.start_layer < 27 <= self.end_layer: + x = self.layer27(x) + if self.start_layer < 28 <= 
self.end_layer: + x = self.layer28(x) + if self.start_layer < 29 <= self.end_layer: + x = self.layer29(x) + if self.start_layer < 30 <= self.end_layer: + x = self.layer30(x) + if self.start_layer < 31 <= self.end_layer: + x = self.layer31(x) + if self.start_layer < 32 <= self.end_layer: + x = self.layer32(x) + if self.start_layer < 33 <= self.end_layer: + x = self.layer33(x) + if self.start_layer < 34 <= self.end_layer: + x = self.layer34(x) + + if self.start_layer < 35 <= self.end_layer: + x = self.layer35(x) + if self.start_layer < 36 <= self.end_layer: + x = self.layer36(x) + if self.start_layer < 37 <= self.end_layer: + x = self.layer37(x) + if self.start_layer < 38 <= self.end_layer: + x = self.layer38(x) + if self.start_layer < 39 <= self.end_layer: + x = self.layer39(x) + if self.start_layer < 40 <= self.end_layer: + x = self.layer40(x) + if self.start_layer < 41 <= self.end_layer: + x = self.layer41(x) + if self.start_layer < 42 <= self.end_layer: + x = self.layer42(x) + if self.start_layer < 43 <= self.end_layer: + x = self.layer43(x) + + if self.start_layer < 44 <= self.end_layer: + x = self.layer44(x) + if self.start_layer < 45 <= self.end_layer: + x = self.layer45(x) + if self.start_layer < 46 <= self.end_layer: + x = self.layer46(x) + if self.start_layer < 47 <= self.end_layer: + x = self.layer47(x) + if self.start_layer < 48 <= self.end_layer: + x = self.layer48(x) + if self.start_layer < 49 <= self.end_layer: + x = self.layer49(x) + if self.start_layer < 50 <= self.end_layer: + x = self.layer50(x) + if self.start_layer < 51 <= self.end_layer: + x = self.layer51(x) + + return x diff --git a/other/2LS/src/model/ViT_CIFAR10.py b/other/2LS/src/model/ViT_CIFAR10.py new file mode 100644 index 0000000..098e206 --- /dev/null +++ b/other/2LS/src/model/ViT_CIFAR10.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=4, mlp_dim=256): + super().__init__() + self.ln1 = 
nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + # Attention + residual + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + x_attn + + # MLP + residual + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class ViT_CIFAR10(nn.Module): + + def __init__(self, start_layer=0, end_layer=12): + super(ViT_CIFAR10, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + img_size = 32 + patch_size = 4 + in_channels = 3 + embed_dim = 128 + num_classes = 10 + num_patches = (img_size // patch_size) ** 2 + + if (self.start_layer < 1) and (self.end_layer >= 1): + self.layer1 = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + if (self.start_layer < 2) and (self.end_layer >= 2): + self.layer2 = nn.Flatten(2) + + if (self.start_layer < 3) and (self.end_layer >= 3): + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + if (self.start_layer < 4) and (self.end_layer >= 4): + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim)) + self.layer4 = nn.Identity() + + if (self.start_layer < 5) and self.end_layer >= 5: + self.layer5 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 6) and (self.end_layer >= 6): + self.layer6 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 7) and (self.end_layer >= 7): + self.layer7 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 8) and (self.end_layer >= 8): + self.layer8 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 9) and (self.end_layer >= 9): + self.layer9 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 10) and (self.end_layer >= 10): + self.layer10 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 11) and 
(self.end_layer >= 11): + self.layer11 = nn.LayerNorm(embed_dim) + + if (self.start_layer < 12) and (self.end_layer >= 12): + self.layer12 = nn.Linear(embed_dim, num_classes) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + x = x.transpose(1, 2) + + if self.start_layer < 3 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) + + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x + self.pos_embed) + + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) + + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x[:, 0]) + + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + return x diff --git a/other/2LS/src/model/ViT_MNIST.py b/other/2LS/src/model/ViT_MNIST.py new file mode 100644 index 0000000..5d21645 --- /dev/null +++ b/other/2LS/src/model/ViT_MNIST.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=4, mlp_dim=256): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + # Attention + residual + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + x_attn + + # MLP + residual + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class 
ViT_MNIST(nn.Module): + + def __init__(self, start_layer=0, end_layer=12): + super(ViT_MNIST, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + img_size = 28 + patch_size = 4 + in_channels = 1 + embed_dim = 128 + num_classes = 10 + num_patches = (img_size // patch_size) ** 2 + + if (self.start_layer < 1) and (self.end_layer >= 1): + self.layer1 = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + if (self.start_layer < 2) and (self.end_layer >= 2): + self.layer2 = nn.Flatten(2) + + if (self.start_layer < 3) and (self.end_layer >= 3): + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + if (self.start_layer < 4) and (self.end_layer >= 4): + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim)) + self.layer4 = nn.Identity() + + if (self.start_layer < 5) and self.end_layer >= 5: + self.layer5 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 6) and (self.end_layer >= 6): + self.layer6 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 7) and (self.end_layer >= 7): + self.layer7 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 8) and (self.end_layer >= 8): + self.layer8 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 9) and (self.end_layer >= 9): + self.layer9 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 10) and (self.end_layer >= 10): + self.layer10 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 11) and (self.end_layer >= 11): + self.layer11 = nn.LayerNorm(embed_dim) + + if (self.start_layer < 12) and (self.end_layer >= 12): + self.layer12 = nn.Linear(embed_dim, num_classes) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + x = x.transpose(1, 2) + + if self.start_layer < 3 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + 
x = torch.cat([cls_token, x], dim=1) + + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x + self.pos_embed) + + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) + + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x[:, 0]) + + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + return x diff --git a/other/2LS/src/model/__init__.py b/other/2LS/src/model/__init__.py new file mode 100644 index 0000000..b0304dd --- /dev/null +++ b/other/2LS/src/model/__init__.py @@ -0,0 +1,6 @@ +from .MobileNetv1_CIFAR10 import * +from .MobileNetv1_MNIST import * +from .VGG16_CIFAR10 import * +from .VGG16_MNIST import * +from .ViT_CIFAR10 import * +from .ViT_MNIST import * From a3d2bd31acb33e11f530f9783b23eb5fcf4f34c2 Mon Sep 17 00:00:00 2001 From: nanhhao04 Date: Thu, 26 Mar 2026 12:31:34 +0700 Subject: [PATCH 2/2] add kwt, bert --- other/2LS/src/Log.py | 33 ++- other/2LS/src/RpcClient.py | 87 ++++-- other/2LS/src/Scheduler.py | 10 +- other/2LS/src/Server.py | 323 ++++++++++++---------- other/2LS/src/Validation.py | 71 ++--- other/2LS/src/dataset/AGNEWS.py | 30 ++ other/2LS/src/dataset/SPEECHCOMMANDS.py | 222 +++++++++++++++ other/2LS/src/dataset/__init__.py | 0 other/2LS/src/dataset/dataloader.py | 171 ++++++++++++ other/2LS/src/model/Bert_AGNEWS.py | 231 ++++++++++++++++ other/2LS/src/model/KWT_SPEECHCOMMANDS.py | 94 +++++++ other/2LS/src/model/__init__.py | 3 + other/2LS/src/train/Bert.py | 167 +++++++++++ other/2LS/src/train/KWT.py | 165 +++++++++++ other/2LS/src/train/VGG16.py | 160 +++++++++++ other/2LS/src/train/__init__.py | 1 + other/2LS/src/val/Bert.py | 44 +++ 
other/2LS/src/val/KWT.py | 39 +++ other/2LS/src/val/VGG16.py | 39 +++ other/2LS/src/val/__init__.py | 0 other/2LS/src/val/get_val.py | 17 ++ 21 files changed, 1694 insertions(+), 213 deletions(-) create mode 100644 other/2LS/src/dataset/AGNEWS.py create mode 100644 other/2LS/src/dataset/SPEECHCOMMANDS.py create mode 100644 other/2LS/src/dataset/__init__.py create mode 100644 other/2LS/src/dataset/dataloader.py create mode 100644 other/2LS/src/model/Bert_AGNEWS.py create mode 100644 other/2LS/src/model/KWT_SPEECHCOMMANDS.py create mode 100644 other/2LS/src/train/Bert.py create mode 100644 other/2LS/src/train/KWT.py create mode 100644 other/2LS/src/train/VGG16.py create mode 100644 other/2LS/src/train/__init__.py create mode 100644 other/2LS/src/val/Bert.py create mode 100644 other/2LS/src/val/KWT.py create mode 100644 other/2LS/src/val/VGG16.py create mode 100644 other/2LS/src/val/__init__.py create mode 100644 other/2LS/src/val/get_val.py diff --git a/other/2LS/src/Log.py b/other/2LS/src/Log.py index 640ad13..f485cbc 100644 --- a/other/2LS/src/Log.py +++ b/other/2LS/src/Log.py @@ -1,4 +1,5 @@ import logging +import os class Colors: @@ -13,25 +14,37 @@ class Colors: class Logger: - def __init__(self, log_path, debug_mode=False): + def __init__(self, log_path, debug_mode=False, minimal=False): # Thiết lập logger với tên "my_logger" self.logger = logging.getLogger("my_logger") self.logger.setLevel(logging.DEBUG) # Mức log self.debug_mode = debug_mode + self.minimal = minimal - # Tạo file handler để ghi log vào file - file_handler = logging.FileHandler(log_path) - file_handler.setLevel(logging.DEBUG) + # Clear existing handlers to avoid duplicate logs if re-initialized + if self.logger.hasHandlers(): + self.logger.handlers.clear() - # Định dạng log - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - file_handler.setFormatter(formatter) + if log_path: + # Tạo thư mục log nếu chưa tồn tại + log_dir = os.path.dirname(log_path) + if 
log_dir: + os.makedirs(log_dir, exist_ok=True) - # Gắn file handler vào logger - self.logger.addHandler(file_handler) + # Tạo file handler để ghi log vào file + file_handler = logging.FileHandler(log_path) + file_handler.setLevel(logging.DEBUG) + + # Định dạng log + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + + # Gắn file handler vào logger + self.logger.addHandler(file_handler) def log_info(self, message): - print(f"[INFO] {message}") + if not self.minimal: + print(f"[INFO] {message}") self.logger.info(message) def log_warning(self, message): diff --git a/other/2LS/src/RpcClient.py b/other/2LS/src/RpcClient.py index 63825e2..7c0db29 100644 --- a/other/2LS/src/RpcClient.py +++ b/other/2LS/src/RpcClient.py @@ -4,6 +4,7 @@ import copy import torchvision import torchvision.transforms as transforms +import torch from collections import defaultdict from tqdm import tqdm @@ -11,6 +12,8 @@ import src.Log from src.model import * +from peft import LoraConfig, get_peft_model + class RpcClient: def __init__(self, client_id, layer_id, channel, train_func, device): @@ -23,6 +26,7 @@ def __init__(self, client_id, layer_id, channel, train_func, device): self.response = None self.model = None self.label_count = None + self.peft_config = None self.train_set = None self.label_to_indices = None @@ -75,49 +79,105 @@ def response_message(self, body): ]) self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + elif data_name == "SPEECHCOMMANDS": + from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + self.train_set = SpeechCommandsDataset(root='./data', subset='training') + elif data_name == "AGNEWS": + from datasets import load_dataset + from transformers import BertTokenizer + from src.dataset.AGNEWS import AGNEWS_DATASET + + dataset = load_dataset('ag_news', download_mode='reuse_dataset_if_exists', cache_dir='./hf_cache') + tokenizer 
= BertTokenizer.from_pretrained('bert-base-cased') + + train_data = dataset['train'] + texts = train_data['text'] + labels = train_data['label'] + + self.train_set = AGNEWS_DATASET(texts, labels, tokenizer, max_length=128) else: self.train_set = None raise ValueError(f"Data name '{data_name}' is not valid.") self.label_to_indices = defaultdict(list) - for idx, (_, label) in tqdm(enumerate(self.train_set)): - self.label_to_indices[int(label)].append(idx) + if hasattr(self.train_set, 'labels'): + for idx, label in enumerate(self.train_set.labels): + self.label_to_indices[int(label)].append(idx) + else: + for idx, (_, label) in tqdm(enumerate(self.train_set)): + self.label_to_indices[int(label)].append(idx) # Load model if self.model is None: - - klass = globals()[f'{model_name}_{data_name}'] + klass = globals().get(f'{model_name}_{data_name}') + if klass is None: + # try alternative names or mappings if needed + if model_name.upper() == 'BERT' and data_name == 'AGNEWS': + klass = Bert_AGNEWS + elif model_name == 'KWT': + klass = KWT_SPEECHCOMMANDS + + if klass is None: + raise ValueError(f"Model class for {model_name} and {data_name} not found.") if cut_layers[1] == -1: self.model = klass(start_layer=cut_layers[0]) else: - self.model = klass(start_layer=cut_layers[0], end_layer=cut_layers[1]) + if klass == Bert_AGNEWS: + self.model = klass(layer_id=1, n_block=cut_layers[1]) if self.layer_id == 1 else klass(layer_id=2, n_block=12 - cut_layers[0]) + else: + self.model = klass(start_layer=cut_layers[0], end_layer=cut_layers[1]) self.model.to(self.device) batch_size = self.response["batch_size"] lr = self.response["lr"] momentum = self.response["momentum"] - out_cluster_id = self.response.get("out_cluster_id", -1) + sda_size = self.response.get("sda_size", 1) + layer2_devices = self.response.get("layer2_devices", []) # Read parameters and load to model if state_dict: self.model.load_state_dict(state_dict) + # Apply LoRA for BERT model + if model_name.upper() == 'BERT': + 
if self.peft_config is None: + self.peft_config = LoraConfig( + task_type="SEQ_CLS", + r=8, lora_alpha=16, lora_dropout=0.1, + bias="none", + target_modules=["query", "key", "value", "dense"] + ) + self.model = get_peft_model(self.model, self.peft_config) + # Note: layer15 might be a placeholder or refer to specific layers in user's model + if self.layer_id == 2: + if hasattr(self.model, 'layer15'): + for param in self.model.layer15.parameters(): + param.requires_grad = True + + self.model.to(self.device) + # Start training if self.layer_id == 1: selected_indices = [] for label, count in enumerate(self.label_count): - selected_indices.extend(random.sample(self.label_to_indices[label], count)) + available = len(self.label_to_indices[label]) + actual_count = min(count, available) + if actual_count > 0: + selected_indices.extend(random.sample(self.label_to_indices[label], actual_count)) subset = torch.utils.data.Subset(self.train_set, selected_indices) train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) - result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, out_cluster_id=out_cluster_id) + result, size = self.train_func(self.model, lr, momentum, train_loader, local_round=local_round, layer2_devices=layer2_devices, model_name=model_name) else: - # Layer 2 handles data asynchronously from Layer 1, no local_round limit - result, size = self.train_func(self.model, lr, momentum, out_cluster_id=out_cluster_id) + result, size = self.train_func(self.model, lr, momentum, None, local_round=local_round, sda_size=sda_size, model_name=model_name) + + # Merge LoRA weights back for BERT + if model_name.upper() == 'BERT': + self.model = self.model.merge_and_unload() # Stop training, then send parameters to server model_state_dict = copy.deepcopy(self.model.state_dict()) @@ -131,18 +191,13 @@ def response_message(self, body): src.Log.print_with_color("[>>>] Client sent parameters to server", "red") 
self.send_to_server(data) return True - elif action == "PAUSE": - return True elif action == "STOP": return False + return True def send_to_server(self, message): - self.response = None - self.channel.queue_declare('rpc_queue', durable=False) self.channel.basic_publish(exchange='', routing_key='rpc_queue', body=pickle.dumps(message)) - - return self.response diff --git a/other/2LS/src/Scheduler.py b/other/2LS/src/Scheduler.py index b7c2186..d5c8a0e 100644 --- a/other/2LS/src/Scheduler.py +++ b/other/2LS/src/Scheduler.py @@ -63,7 +63,7 @@ def send_to_server(self, message): routing_key='rpc_queue', body=pickle.dumps(message)) - def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=1, out_cluster_id=-1): + def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_round=1, out_cluster_id=-1, **kwargs): optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' @@ -134,7 +134,7 @@ def train_on_first_layer(self, model, lr, momentum, train_loader=None, local_rou return True time.sleep(0.5) - def train_on_last_layer(self, model, lr, momentum, out_cluster_id=-1): + def train_on_last_layer(self, model, lr, momentum, out_cluster_id=-1, **kwargs): optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) result = True @@ -220,11 +220,11 @@ def train_on_last_layer(self, model, lr, momentum, out_cluster_id=-1): if received_data["action"] == "PAUSE": return result - def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, out_cluster_id=-1): + def train_on_device(self, model, lr, momentum, train_loader=None, local_round=None, out_cluster_id=-1, **kwargs): self.data_count = 0 if self.layer_id == 1: - result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, out_cluster_id=out_cluster_id) + result = self.train_on_first_layer(model, lr, momentum, train_loader, local_round, 
out_cluster_id=out_cluster_id, **kwargs) else: - result = self.train_on_last_layer(model, lr, momentum, out_cluster_id=out_cluster_id) + result = self.train_on_last_layer(model, lr, momentum, out_cluster_id=out_cluster_id, **kwargs) return result, self.data_count \ No newline at end of file diff --git a/other/2LS/src/Server.py b/other/2LS/src/Server.py index 20b371f..6ee2e04 100644 --- a/other/2LS/src/Server.py +++ b/other/2LS/src/Server.py @@ -1,4 +1,5 @@ import os +import time import random import pika import pickle @@ -7,7 +8,10 @@ import copy import src.Log import src.Utils -import src.Validation +from src.val.get_val import get_val +from src.model.Bert_AGNEWS import Bert_AGNEWS +from src.model.KWT_SPEECHCOMMANDS import KWT_SPEECHCOMMANDS +from src.model.VGG16_CIFAR10 import VGG16_CIFAR10 from src.model import * @@ -39,7 +43,7 @@ def __init__(self, config): self.current_out_cluster_cursor = 0 # Pointer current out-cluster self.current_out_cluster_idx = 0 self.finished_clients_in_cluster = {} # {(out_idx, in_idx): count} - self.finished_upper_clients_count = {} # {out_idx: count} — Phase 2 only + self.finished_upper_clients_count = {} # {out_idx: count} # FedAsync: track thứ tự in-cluster đến cho mỗi out-cluster self.incluster_fedasync_order = {} # {out_idx: [in_idx_first, in_idx_second, ...]} @@ -66,9 +70,21 @@ def __init__(self, config): log_path = config["log_path"] credentials = pika.PlainCredentials(username, password) - self.connection = pika.BlockingConnection( - pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) - self.channel = self.connection.channel() + try: + parameters = pika.ConnectionParameters( + host=address, + port=5672, + virtual_host=virtual_host, + credentials=credentials, + heartbeat=0, + socket_timeout=5 + ) + self.connection = pika.BlockingConnection(parameters) + self.channel = self.connection.channel() + src.Log.print_with_color(f"[OK] Server connected to RabbitMQ at {address}", "green") + except Exception as e: + 
src.Log.print_with_color(f"[ERROR] Server failed to connect to RabbitMQ at {address}: {e}", "red") + sys.exit(1) self.channel.queue_declare(queue='rpc_queue') self.current_clients = 0 @@ -92,24 +108,40 @@ def __init__(self, config): self.reply_channel = self.connection.channel() self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request) + self.start_time = time.time() debug_mode = config["debug_mode"] - self.logger = src.Log.Logger(f"{log_path}/app.log", debug_mode) + # Server logger now goes to server.log + self.server_logger = src.Log.Logger(f"{log_path}/server.log", debug_mode) + # Accuracy logger is removed, using server_logger instead + src.Log.print_with_color(f"Application start. Server is waiting for {self.total_clients} clients.", "green") - self.logger.log_info(f"Application start. Server is waiting for {self.total_clients} clients.") + self.server_logger.log_info(f"Application start. Server is waiting for {self.total_clients} clients.") def distribution(self): if self.non_iid: - label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], - ]) + if self.data_name == "SPEECHCOMMANDS": + label_distribution = np.array([[0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + ]) + else: + label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + ]) self.label_counts = (label_distribution * self.num_sample).astype(int) self.label_ = copy.deepcopy(self.label_counts) @@ -144,7 +176,10 @@ def on_request(self, ch, method, props, body): self.distribution() filepath = f'{self.model_name}_{self.data_name}.pth' - initial_sd = torch.load(filepath, weights_only=True) if os.path.exists(filepath) else {} + if os.path.exists(filepath): + initial_sd = torch.load(filepath, weights_only=True) + else: + initial_sd = {} # node[3] là out_cluster_id unique_out_clusters = sorted(list(set(node[3] for node in self.list_clients))) # thứ tự 0 1 2 @@ -160,7 +195,7 @@ def on_request(self, ch, method, props, body): src.Log.print_with_color(f"All clients connected. 
Shuffled Out-cluster order: {self.out_cluster_order}", "green") src.Log.print_with_color("Hierarchical structure initialized from predefined IDs.", "green") - self.logger.log_info(f"Start training round {self.global_round - self.round + 1}") + self.server_logger.log_info(f"Start training round {self.global_round - self.round + 1}") self.notify_clients(register=False) elif action == "NOTIFY": @@ -241,39 +276,23 @@ def on_request(self, ch, method, props, body): order_list = self.incluster_fedasync_order[out_idx] order_list.append(in_idx) - alpha = 1.0 if len(order_list) == 1 else 0.5 - src.Log.print_with_color(f">>> In-cluster ({out_idx}, {in_idx}) arrived {'FIRST' if alpha == 1.0 else 'LATER'}. alpha={alpha}", "green") + alpha = 0.5 if len(order_list) == 1 else 0.25 + src.Log.print_with_color(f">>> In-cluster ({out_idx}, {in_idx}) arrived {'FIRST' if alpha == 0.5 else 'LATER'}. alpha={alpha}", "green") # Lưu L1 FedAvg result — CHƯA FedAsync, chờ L2 self.incluster_l1_avg[(out_idx, in_idx)] = in_cluster_avg_sd - # Pause L1 clients - message_pause_l1 = {"action": "PAUSE", "message": "In-cluster done. 
Waiting.", "parameters": None} - for cid in clients_in_cluster: - self.send_to_response(cid, pickle.dumps(message_pause_l1)) - # Gửi PAUSE cho L2 clients thuộc in-cluster này l2_clients_this_ic = [n for n in self.list_clients if int(n[1]) > 1 and int(n[2]) == in_idx] + message_pause_l2 = {"action": "PAUSE", "message": f"Send L2 parameters for in-cluster {in_idx}.", "parameters": None} + seen = set() + for node in reversed(l2_clients_this_ic): + role_key = (node[1], node[2], node[3], node[4]) + if role_key not in seen: + seen.add(role_key) + src.Log.print_with_color(f">>> Sending PAUSE to L2 client {node[0]} for in-cluster {in_idx}", "yellow") + self.send_to_response(node[0], pickle.dumps(message_pause_l2)) - if l2_clients_this_ic: - message_pause_l2 = {"action": "PAUSE", "message": f"Send L2 parameters for in-cluster {in_idx}.", "parameters": None} - seen = set() - for node in reversed(l2_clients_this_ic): - role_key = (node[1], node[2], node[3], node[4]) - if role_key not in seen: - seen.add(role_key) - src.Log.print_with_color(f">>> Sending PAUSE to L2 client {node[0]} for in-cluster {in_idx}", "yellow") - self.send_to_response(node[0], pickle.dumps(message_pause_l2)) - else: - # Không có L2 → FedAsync L1 trực tiếp (paper: model chỉ có L1) - src.Log.print_with_color(f">>> No L2 for in-cluster ({out_idx}, {in_idx}). FedAsync L1 only, alpha={alpha}", "green") - self.fedasync_aggregate(out_idx, in_cluster_avg_sd, alpha=alpha) - self.incluster_l1_avg.pop((out_idx, in_idx), None) - self.check_out_cluster_completion(out_idx) - - else: - message_pause = {"action": "PAUSE", "message": "Round mismatch. 
Waiting...", "parameters": None} - self.send_to_response(cid_str, pickle.dumps(message_pause)) elif layer_id > 1: # Accumulate L2 update per in-cluster @@ -310,7 +329,7 @@ def on_request(self, ch, method, props, body): # Lấy alpha theo thứ tự in-cluster đến order_list = self.incluster_fedasync_order.get(out_idx, []) arrival_pos = order_list.index(in_idx) if in_idx in order_list else 0 - alpha = 1.0 if arrival_pos == 0 else 0.5 + alpha = 0.5 if arrival_pos == 0 else 0.25 src.Log.print_with_color(f">>> In-cluster ({out_idx}, {in_idx}) L1+L2 merged. FedAsync alpha={alpha} (paper Alg.1)", "green") self.fedasync_aggregate(out_idx, merged_sd, alpha=alpha) @@ -318,65 +337,54 @@ def on_request(self, ch, method, props, body): self.check_out_cluster_completion(out_idx) def check_out_cluster_completion(self, out_idx): - # Kiểm tra tất cả in-cluster đã FedAsync xong - all_l1_in_oc = set(int(n[2]) for n in self.list_clients if int(n[1]) == 1 and int(n[3]) == out_idx) + all_in_clusters = set(int(n[2]) for n in self.list_clients if int(n[1]) == 1 and int(n[3]) == out_idx) order_list = self.incluster_fedasync_order.get(out_idx, []) - is_done = len(order_list) >= len(all_l1_in_oc) if all_l1_in_oc else True - - # Kiểm tra L2 per in-cluster: tất cả in-cluster phải hoàn thành L2 - all_l2_done = True - for ic_idx in all_l1_in_oc: - l2_count_for_ic = len(set((n[1], n[2], n[3], n[4]) for n in self.list_clients if int(n[1]) > 1 and int(n[2]) == ic_idx)) - if l2_count_for_ic > 0: - finished_l2_ic = self.incluster_l2_finished.get((out_idx, ic_idx), 0) - if finished_l2_ic < l2_count_for_ic: - all_l2_done = False - break - - if is_done and all_l2_done: - src.Log.print_with_color(f">>> Out-cluster {out_idx} FULLY completed (L1 & L2+).", "green") - - # Validation: tổng hợp mô hình và in accuracy - state_dict_full = self.out_cluster_models[out_idx] - if len(state_dict_full) > 0: - src.Log.print_with_color(f">>> Running validation for Out-cluster {out_idx}...", "yellow") - 
src.Validation.test(self.model_name, self.data_name, state_dict_full, self.logger) - - # Reset cho out-cluster này - self.incluster_fedasync_order[out_idx] = [] - self.finished_upper_clients_count[out_idx] = 0 - # Reset L2 per in-cluster counters - for ic_idx in all_l1_in_oc: - self.incluster_l2_finished.pop((out_idx, ic_idx), None) - self.incluster_l1_avg.pop((out_idx, ic_idx), None) - - # chuyển sang outcluster tiếp - self.current_out_cluster_cursor += 1 - - if self.current_out_cluster_cursor >= len(self.out_cluster_order): - # xong 1 round ( chạy qua hết outcluster) - self.round -= 1 - if self.round <= 0: - src.Log.print_with_color(">>> All global rounds completed.", "green") - state_dict_full = self.out_cluster_models[out_idx] - torch.save(state_dict_full, f'{self.model_name}_{self.data_name}.pth') - src.Log.print_with_color(">>> Server training process total completion.", "green") - return - - # next round - self.current_out_cluster_cursor = 0 - random.shuffle(self.out_cluster_order) - src.Log.print_with_color( - f">>> New Global Round. 
Shuffled Out-cluster order: {self.out_cluster_order}", "green") - # Set next out-cluster - next_out_idx = self.out_cluster_order[self.current_out_cluster_cursor] - self.out_cluster_models[next_out_idx] = copy.deepcopy(self.out_cluster_models[out_idx]) - self.current_out_cluster_idx = next_out_idx - - # Start next Out-cluster - src.Log.print_with_color(f">>> Moving to Out-cluster {self.current_out_cluster_idx}", "yellow") - self.notify_clients(register=False) + # Chưa đủ in-cluster hoàn thành + if all_in_clusters and len(order_list) < len(all_in_clusters): + return + # Kiểm tra L2 per in-cluster + for ic_idx in all_in_clusters: + l2_roles = len(set((n[1], n[2], n[3], n[4]) for n in self.list_clients if int(n[1]) > 1 and int(n[2]) == ic_idx)) + if l2_roles > 0 and self.incluster_l2_finished.get((out_idx, ic_idx), 0) < l2_roles: + return + + src.Log.print_with_color(f">>> Out-cluster {out_idx} FULLY completed (L1 & L2+).", "green") + + # Validation + sd = self.out_cluster_models[out_idx] + if sd: + current_r = self.global_round - self.round + 1 + elapsed_min = (time.time() - self.start_time) / 60 + self.server_logger.log_info(f"Round {current_r} ({elapsed_min:.2f} min):") + get_val(self.model_name, self.data_name, sd, self.server_logger) + + # Reset + self.incluster_fedasync_order[out_idx] = [] + self.finished_upper_clients_count[out_idx] = 0 + for ic_idx in all_in_clusters: + self.incluster_l2_finished.pop((out_idx, ic_idx), None) + self.incluster_l1_avg.pop((out_idx, ic_idx), None) + + # Chuyển sang out-cluster tiếp + self.current_out_cluster_cursor += 1 + if self.current_out_cluster_cursor >= len(self.out_cluster_order): + self.round -= 1 + if self.round <= 0: + src.Log.print_with_color(">>> All global rounds completed.", "green") + torch.save(sd, f'{self.model_name}_{self.data_name}.pth') + src.Log.print_with_color(">>> Server training process total completion.", "green") + return + self.current_out_cluster_cursor = 0 + random.shuffle(self.out_cluster_order) + 
src.Log.print_with_color(f">>> New Global Round. Shuffled order: {self.out_cluster_order}", "green") + + next_out_idx = self.out_cluster_order[self.current_out_cluster_cursor] + #weight model outcluster tiếp + self.out_cluster_models[next_out_idx] = copy.deepcopy(self.out_cluster_models[out_idx]) + self.current_out_cluster_idx = next_out_idx + src.Log.print_with_color(f">>> Moving to Out-cluster {next_out_idx}", "yellow") + self.notify_clients(register=False) # FedAsync: W_new = (1-alpha)*W_old + alpha*W_received def fedasync_aggregate(self, out_idx, in_cluster_sd, alpha=1.0): @@ -396,42 +404,71 @@ def notify_clients(self, start=True, register=True): self.send_to_response(node[0], pickle.dumps({"action": "STOP", "message": "Stop!", "parameters": None})) return - if register: - pass - else: - o_idx = self.current_out_cluster_idx - full_sd = self.out_cluster_models[o_idx] - cut = int(self.list_cut_layers[0][0] if isinstance(self.list_cut_layers[0], list) else self.list_cut_layers[0]) - klass = globals()[f'{self.model_name}_{self.data_name}'] - - src.Log.print_with_color(f">>> Starting Training Round for Out-cluster {o_idx}...", "red") - - # Notify cho những thằng ở o_idx hiện tại thoi - for role in set(n[1:] for n in self.list_clients): - layer_id, in_idx, out_idx, idx = role - if not ((layer_id == 1 and out_idx == o_idx) or layer_id > 1): - continue + o_idx = self.current_out_cluster_idx + full_sd = self.out_cluster_models[o_idx] + cut = int(self.list_cut_layers[0][0] if isinstance(self.list_cut_layers[0], list) else self.list_cut_layers[0]) - client_id = next(n[0] for n in reversed(self.list_clients) if n[1:] == role) - layers = [0, cut] if layer_id == 1 else [cut, -1] - model = klass(end_layer=cut) if layer_id == 1 else klass(start_layer=cut) + # Dynamic model class lookup + model_name = self.model_name + data_name = self.data_name + + if model_name in ['Bert', 'BERT']: + klass = Bert_AGNEWS + + elif model_name == 'KWT': + klass = KWT_SPEECHCOMMANDS + else: + 
klass_name = f"{model_name}_{data_name}" + klass = globals().get(klass_name) - state_dict = model.state_dict() - if len(full_sd) > 0: - for key in state_dict.keys(): - if key in full_sd: - state_dict[key] = full_sd[key] + if klass is None: + self.server_logger.log_error(f"Model class for {model_name} and {data_name} not found.") + return - label = self.label_[idx] if (layer_id == 1 and self.label_) else [] + src.Log.print_with_color(f">>> Starting Training Round for Out-cluster {o_idx}...", "red") - response = {"action": "START", "message": "Training Start", "parameters": state_dict, - "layers": layers, "model_name": self.model_name, "data_name": self.data_name, - "batch_size": self.batch_size, "lr": self.lr, "momentum": self.momentum, - "label_count": label, "local_round": self.local_round, "cluster": in_idx, - "out_cluster_id": o_idx} + # Notify active clients + for role in set(n[1:] for n in self.list_clients): + layer_id, in_idx, out_idx, idx = role + if not ((layer_id == 1 and out_idx == o_idx) or layer_id > 1): + continue - src.Log.print_with_color(f">>> Notifying client {client_id} (layer {layer_id}) for out-cluster {o_idx}, in-cluster {in_idx}", "yellow") - self.send_to_response(client_id, pickle.dumps(response)) + client_id = next(n[0] for n in reversed(self.list_clients) if n[1:] == role) + layers = [0, cut] if layer_id == 1 else [cut, -1] + + # Initialize partial model to get correct state_dict keys + if klass == Bert_AGNEWS: + if layer_id == 1: + model = klass(layer_id=1, n_block=cut) + else: + model = klass(layer_id=2, n_block=12 - cut) + else: + if layer_id == 1: + model = klass(end_layer=cut) + else: + model = klass(start_layer=cut) + + state_dict = model.state_dict() + if len(full_sd) > 0: + for key in state_dict.keys(): + if key in full_sd: + state_dict[key] = full_sd[key] + elif layer_id == 2 and klass == Bert_AGNEWS: + # Bert_AGNEWS layer 2 uses offset keys + offset_key = f"bert.encoder.layer.{int(key.split('.')[3]) + 
cut}.{'.'.join(key.split('.')[4:])}" if "bert.encoder.layer" in key else key + if offset_key in full_sd: + state_dict[key] = full_sd[offset_key] + + label = self.label_[idx] if (layer_id == 1 and self.label_) else [] + + response = {"action": "START", "message": "Training Start", "parameters": state_dict, + "layers": layers, "model_name": self.model_name, "data_name": self.data_name, + "batch_size": self.batch_size, "lr": self.lr, "momentum": self.momentum, + "label_count": label, "local_round": self.local_round, "cluster": in_idx, + "out_cluster_id": o_idx} + + src.Log.print_with_color(f">>> Notifying client {client_id} (layer {layer_id}) for out-cluster {o_idx}, in-cluster {in_idx}", "yellow") + self.send_to_response(client_id, pickle.dumps(response)) def start(self): self.channel.start_consuming() @@ -442,19 +479,5 @@ def send_to_response(self, client_id, message): src.Log.print_with_color(f"[>>>] Sent notification to client {client_id}", "red") self.reply_channel.basic_publish(exchange='', routing_key=reply_queue_name, body=message) - def avg_all_parameters(self): - layer_sizes = self.global_client_sizes - layer_params = self.global_model_parameters - for layer_idx, list_state_dicts in enumerate(layer_params): - list_sizes = layer_sizes[layer_idx] - if not list_state_dicts or not list_sizes: - self.avg_state_dict.append({}) - continue - avg_sd = src.Utils.fedavg_state_dicts(list_state_dicts, weights=list_sizes) - self.avg_state_dict.append(avg_sd) - - def concatenate_and_avg_clusters(self): - full_dict = {} - for sd in self.avg_state_dict: - full_dict.update(copy.deepcopy(sd)) - return full_dict \ No newline at end of file + + diff --git a/other/2LS/src/Validation.py b/other/2LS/src/Validation.py index fb40cb9..6e9feb5 100644 --- a/other/2LS/src/Validation.py +++ b/other/2LS/src/Validation.py @@ -1,65 +1,72 @@ - import numpy as np import math from tqdm import tqdm - +import torch import torchvision import torchvision.transforms as transforms import 
torch.nn.functional as F +from src.dataset.dataloader import data_loader from src.model import * def test(model_name, data_name, state_dict_full, logger): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if data_name == "MNIST": - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) - testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test) + + # Use centralized data_loader + try: + test_loader = data_loader(data_name=data_name, train=False) + except ValueError as e: + logger.log_error(str(e)) + return False - elif data_name == "CIFAR10": - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + # Dynamic model class lookup + if model_name in ['Bert', 'BERT']: + if data_name == 'AGNEWS': + klass = Bert_AGNEWS + else: + klass = BERT_EMOTION + elif model_name == 'KWT': + klass = KWT_SPEECHCOMMANDS else: - raise ValueError(f"Data name '{data_name}' is not valid.") - - test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) - - - klass = globals()[f'{model_name}_{data_name}'] + # globals() might not contain all models if they are only in src.model + # but we did 'from src.model import *' + klass_name = f"{model_name}_{data_name}" + klass = globals().get(klass_name) if klass is None: - raise ValueError(f"Class '{model_name}' does not exist.") + logger.log_error(f"Model class for {model_name} and {data_name} not found.") + return False + # Initialize full model model = klass() model.load_state_dict(state_dict_full) model = model.to(device) + # evaluation mode model.eval() test_loss = 0 correct = 0 - for data, target in tqdm(test_loader): - data = data.to(device) - target = target.to(device) - output = 
model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() - pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability - correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() + with torch.no_grad(): + for data, target in tqdm(test_loader): + data = data.to(device) + target = target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() test_loss /= len(test_loader.dataset) accuracy = 100.0 * correct / len(test_loader.dataset) - print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), accuracy)) + + report = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), accuracy) + + print(report) if np.isnan(test_loss) or math.isnan(test_loss) or abs(test_loss) > 10e5: return False else: - logger.log_info('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), accuracy)) + logger.log_info(report) return True diff --git a/other/2LS/src/dataset/AGNEWS.py b/other/2LS/src/dataset/AGNEWS.py new file mode 100644 index 0000000..fabd126 --- /dev/null +++ b/other/2LS/src/dataset/AGNEWS.py @@ -0,0 +1,30 @@ +import torch + +class AGNEWS_DATASET(torch.utils.data.Dataset): + def __init__(self, texts, labels, tokenizer, max_length=128): + self.texts = texts + self.labels = labels + self.tokenizer = tokenizer + self.max_length = max_length + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = str(self.texts[idx]) + label = self.labels[idx] + + # Tokenize + encoding = self.tokenizer( + text, + truncation=True, + padding='max_length', + max_length=self.max_length, + return_tensors='pt' + ) + + return { + 
def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=160, n_mels=40):
    """Compute MFCC features using scipy (no torchaudio backend needed).

    Args:
        waveform: 1-D array of audio samples.
        sample_rate: Sampling rate used to place the mel filter bank.
        n_mfcc: Number of cepstral coefficients kept after the DCT.
        n_fft: FFT/window size in samples.
        hop_length: Hop between successive frames in samples.
        n_mels: Number of triangular mel filters.

    Returns:
        ndarray of shape (n_mfcc, num_frames).
    """
    # Pre-emphasis: boost high frequencies, first sample passes unchanged.
    boosted = np.append(waveform[0], waveform[1:] - 0.97 * waveform[:-1])

    # Frame the signal with an index matrix (same values as a per-frame copy
    # loop), then apply a Hamming window.
    win_len = n_fft
    n_windows = 1 + (len(boosted) - win_len) // hop_length
    offsets = hop_length * np.arange(n_windows)[:, None] + np.arange(win_len)[None, :]
    windows = boosted[offsets] * np.hamming(win_len)

    # Magnitude spectrum -> power spectrum.
    spectrum = np.absolute(np.fft.rfft(windows, n_fft))
    power = (1.0 / n_fft) * (spectrum ** 2)

    # Mel filter-bank edges, evenly spaced on the mel scale.
    max_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700)
    mel_grid = np.linspace(0, max_mel, n_mels + 2)
    hz_grid = 700 * (10 ** (mel_grid / 2595) - 1)
    edges = np.floor((n_fft + 1) * hz_grid / sample_rate).astype(int)

    # Triangular filters: rising slope up to the center bin, falling after.
    fbank = np.zeros((n_mels, int(np.floor(n_fft / 2 + 1))))
    for band in range(n_mels):
        lo, mid, hi = edges[band], edges[band + 1], edges[band + 2]
        for k in range(lo, mid):
            fbank[band, k] = (k - lo) / (mid - lo)
        for k in range(mid, hi):
            fbank[band, k] = (hi - k) / (hi - mid)

    # Log mel energies (eps floor avoids log(0)), then type-II DCT.
    energies = np.dot(power, fbank.T)
    energies = np.where(energies == 0, np.finfo(float).eps, energies)
    energies = 20 * np.log10(energies)

    return dct(energies, type=2, axis=1, norm='ortho')[:, :n_mfcc].T
Extracting...") + with tarfile.open(tar_path, 'r:gz') as tar: + tar.extractall(path=root_dir) + print(f"Extracted to {root_dir}") + except Exception as e: + raise RuntimeError(f"Failed to download/extract dataset: {e}") + +def _load_background_noise(root): + """Load all background noise files from _background_noise_ directory""" + noise_dir = os.path.join(root, '_background_noise_') + noise_data = [] + if not os.path.exists(noise_dir): + print(f"[WARNING] Background noise dir not found: {noise_dir}") + return noise_data + + for f in os.listdir(noise_dir): + if not f.endswith('.wav'): + continue + fpath = os.path.join(noise_dir, f) + try: + sr, wav = wavfile.read(fpath) + if wav.dtype == np.int16: + wav = wav.astype(np.float32) / 32768.0 + elif wav.dtype == np.int32: + wav = wav.astype(np.float32) / 2147483648.0 + else: + wav = wav.astype(np.float32) + noise_data.append(wav) + except Exception as e: + print(f"[WARNING] Error reading noise file {fpath}: {e}") + + print(f"Loaded {len(noise_data)} background noise files") + return noise_data + +class SpeechCommandsDataset(Dataset): + def __init__(self, root='./data', subset='training', n_silence=2300): + self.root = os.path.join(root, 'SpeechCommands', 'speech_commands_v0.02') + self.subset = subset + self.samples = [] + self.noise_data = [] + + if not os.path.exists(self.root): + _download_and_extract(self.root) + + self.noise_data = _load_background_noise(self.root) + + val_list = set() + test_list = set() + + val_file = os.path.join(self.root, 'validation_list.txt') + test_file = os.path.join(self.root, 'testing_list.txt') + + if os.path.exists(val_file): + with open(val_file) as f: + val_list = set(line.strip() for line in f) + if os.path.exists(test_file): + with open(test_file) as f: + test_list = set(line.strip() for line in f) + + for label_dir in os.listdir(self.root): + label_path = os.path.join(self.root, label_dir) + if not os.path.isdir(label_path) or label_dir.startswith('_'): + continue + + for 
audio_file in os.listdir(label_path): + if not audio_file.endswith('.wav'): + continue + + rel_path = os.path.join(label_dir, audio_file) + + if subset == 'validation' and rel_path in val_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'testing' and rel_path in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'training' and rel_path not in val_list and rel_path not in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + + if self.noise_data: + if subset == 'training': + num_silence = n_silence + else: + num_silence = max(1, n_silence // 9) + + for _ in range(num_silence): + self.samples.append((None, 'silence')) + + print(f"Added {num_silence} silence samples from background noise") + + # Pre-calculate labels for RpcClient efficiency + self.labels = [] + for _, label in self.samples: + if label in CLASSES: + self.labels.append(CLASSES.index(label)) + else: + self.labels.append(CLASSES.index('unknown')) + + print(f"Total {len(self.samples)} samples for {subset}") + + def _get_silence_waveform(self): + target_length = 16000 + noise = random.choice(self.noise_data) + + if len(noise) <= target_length: + waveform = np.pad(noise, (0, max(0, target_length - len(noise)))) + else: + start = random.randint(0, len(noise) - target_length) + waveform = noise[start: start + target_length] + + return waveform + + def __getitem__(self, idx): + audio_path, label = self.samples[idx] + + try: + if audio_path is None: + waveform = self._get_silence_waveform() + else: + sample_rate, waveform = wavfile.read(audio_path) + if waveform.dtype == np.int16: + waveform = waveform.astype(np.float32) / 32768.0 + elif waveform.dtype == np.int32: + waveform = waveform.astype(np.float32) / 2147483648.0 + else: + waveform = waveform.astype(np.float32) + + target_length = 16000 + if len(waveform) < target_length: + waveform = np.pad(waveform, (0, target_length - 
len(waveform))) + else: + waveform = waveform[:target_length] + + mfcc = compute_mfcc(waveform, sample_rate=16000, n_mfcc=40) + mfcc = torch.tensor(mfcc, dtype=torch.float32) + + except Exception as e: + print(f"[WARNING] Error processing sample {idx}: {e}, using zeros") + mfcc = torch.zeros(40, 98, dtype=torch.float32) + + if label in CLASSES: + label_idx = CLASSES.index(label) + else: + label_idx = CLASSES.index('unknown') + + return mfcc, label_idx + + def __len__(self): + return len(self.samples) \ No newline at end of file diff --git a/other/2LS/src/dataset/__init__.py b/other/2LS/src/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/other/2LS/src/dataset/dataloader.py b/other/2LS/src/dataset/dataloader.py new file mode 100644 index 0000000..6518f7c --- /dev/null +++ b/other/2LS/src/dataset/dataloader.py @@ -0,0 +1,171 @@ +import random + +import torch +import torchvision +from collections import defaultdict +from tqdm import tqdm + +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from transformers import BertTokenizer +from datasets import load_dataset + +from src.dataset.AGNEWS import AGNEWS_DATASET +from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + +def AGNEWS(batch_size=None, distribution=None, train=True): + cache_dir = './hf_cache' + print(f"Loading AGNEWS dataset with cache_dir={cache_dir}...") + dataset = load_dataset( + 'ag_news', + download_mode='reuse_dataset_if_exists', + cache_dir=cache_dir + ) + print("Dataset loaded successfully.") + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + if train: + train_data = dataset['train'] + train_target_counts = {k: v for k, v in enumerate(distribution)} + train_by_class = defaultdict(list) + for text, label in zip(train_data['text'], train_data['label']): + train_by_class[label].append((text, label)) + + train_texts, train_labels = [], [] + for label, count in train_target_counts.items(): + samples = 
def CIFAR10(batch_size=None, distribution=None, train=True):
    """Build a CIFAR10 DataLoader.

    Args:
        batch_size: Training batch size (the test split always uses 100).
        distribution: Per-class sample counts indexed by label 0-9; only
            used for the training split.
        train: True -> class-balanced training subset, False -> full test set.

    Returns:
        torch.utils.data.DataLoader
    """
    if train:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                                 transform=transform_train)

        # Read labels from the dataset's `targets` list instead of iterating
        # the dataset itself: the original loop decoded and transformed all
        # 50k images just to look at their labels.
        label_to_indices = defaultdict(list)
        for idx, label in enumerate(train_set.targets):
            label_to_indices[int(label)].append(idx)

        # Draw exactly `distribution[label]` examples per class.
        selected_indices = []
        for label, count in enumerate(distribution):
            selected_indices.extend(random.sample(label_to_indices[label], count))
        subset = torch.utils.data.Subset(train_set, selected_indices)

        train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True)

        return train_loader
    else:
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                                transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=1)
        return test_loader
def data_loader(data_name=None, batch_size=None, distribution=None, train=True):
    """Dispatch to the dataset-specific loader factory.

    Args:
        data_name: One of 'AGNEWS', 'SPEECHCOMMANDS', 'MNIST', 'CIFAR10'.
        batch_size: Training batch size forwarded to the factory.
        distribution: Per-class sample counts forwarded to the factory.
        train: Whether to build the training or the test loader.

    Returns:
        A torch DataLoader for the requested dataset/split.

    Raises:
        ValueError: If `data_name` is not one of the supported keys.
    """
    # Lambdas keep the lookup lazy: only the chosen factory ever runs.
    factories = {
        'AGNEWS': lambda: AGNEWS(batch_size, distribution, train),
        'SPEECHCOMMANDS': lambda: SPEECHCOMMANDS(batch_size, distribution, train),
        'MNIST': lambda: MNIST(batch_size, distribution, train),
        'CIFAR10': lambda: CIFAR10(batch_size, distribution, train),
    }

    if data_name not in factories:
        raise ValueError(f"Dataset {data_name} not supported.")

    return factories[data_name]()
class BertEmbeddings(nn.Module):
    """Sum of word, position and segment embeddings + LayerNorm + dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        # Position ids 0..seq_len-1, broadcast to the batch shape.
        seq_len = input_ids.size(1)
        pos_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        pos_ids = pos_ids.unsqueeze(0).expand_as(input_ids)

        # Segment ids default to a single all-zero segment.
        seg_ids = token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(pos_ids)
                  + self.token_type_embeddings(seg_ids))
        return self.dropout(self.LayerNorm(summed))
class BertSelfOutput(nn.Module):
    """Projection + dropout, then residual add and LayerNorm."""

    def __init__(self, hidden_size, dropout_prob):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm.
        return self.LayerNorm(projected + input_tensor)
class BertPooler(nn.Module):
    """Tanh projection of the first ([CLS]) token's hidden state."""

    def __init__(self, hidden_size):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only sequence position 0 feeds the pooled representation.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
self.classifier = nn.Linear(hidden_size, num_labels) + + def forward(self, pooled_output): + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + +class Bert_AGNEWS(nn.Module): + def __init__( self, vocab_size=28996, hidden_size=768, num_attention_heads=12, intermediate_size=3072, + max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, layer_id=0, n_block=12 + ): + super(Bert_AGNEWS, self).__init__() + self.layer_id = layer_id + self.config = DotDict( + model_type="bert", + vocab_size=vocab_size, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + bos_token_id=101, eos_token_id=102, pad_token_id=0, + is_encoder_decoder=False, tie_word_embeddings=False, + use_return_dict=True, output_attentions=False, output_hidden_states=False + ) + + if self.layer_id == 1: + self.embeddings = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size,dropout_prob=dropout_prob) + self.layers = nn.ModuleList( + [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob) + for _ in range(n_block)] + ) + elif self.layer_id == 2: + self.layers = nn.ModuleList( + [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob) + for _ in range(n_block)] + ) + self.pooler = BertPooler(hidden_size) + self.dropout = nn.Dropout(dropout_prob) + self.classifier = nn.Linear(hidden_size, 4) + else: + self.embeddings = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) + self.layers = nn.ModuleList( + [BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob) + for _ in range(n_block)] + ) + self.pooler = BertPooler(hidden_size) + self.dropout = nn.Dropout(dropout_prob) 
class TransformerEncoderBlock(nn.Module):
    """Pre-LN transformer block: self-attention and an MLP, each residual."""

    def __init__(self, embed_dim, num_heads=1, mlp_dim=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
        )

    def forward(self, x):
        # Attention sub-layer with residual.
        normed = self.ln1(x)
        attn_out, _ = self.mha(normed, normed, normed)
        x = x + attn_out
        # Feed-forward sub-layer with residual.
        return x + self.mlp(self.ln2(x))
1, embed_dim)) + self.dropout = nn.Dropout(dropout) + + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + TransformerEncoderBlock(embed_dim, num_heads, mlp_dim)) + + if self.start_layer < 16 <= self.end_layer: + self.layer16 = nn.LayerNorm(embed_dim) + + if self.start_layer < 17 <= self.end_layer: + self.layer17 = nn.Linear(embed_dim, num_classes) + + self._init_weights() + + def _init_weights(self): + if hasattr(self, 'cls_token'): + nn.init.trunc_normal_(self.cls_token, std=0.02) + if hasattr(self, 'pos_embed'): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = x.transpose(1, 2) + x = self.layer1(x) + + if self.start_layer < 2 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) + + if self.start_layer < 3 <= self.end_layer: + x = x + self.pos_embed + x = self.dropout(x) + + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + x = getattr(self, f'layer{layer_idx}')(x) + + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x[:, 0]) + + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + + return x \ No newline at end of file diff --git a/other/2LS/src/model/__init__.py b/other/2LS/src/model/__init__.py index b0304dd..50d24da 100644 --- a/other/2LS/src/model/__init__.py +++ b/other/2LS/src/model/__init__.py @@ -4,3 +4,6 @@ from .VGG16_MNIST import * from .ViT_CIFAR10 import * from .ViT_MNIST import * +from .Bert_AGNEWS import * +from .BERT_EMOTION import * +from .KWT_SPEECHCOMMANDS import * diff --git a/other/2LS/src/train/Bert.py b/other/2LS/src/train/Bert.py new file mode 100644 index 0000000..6c5c916 --- /dev/null +++ b/other/2LS/src/train/Bert.py @@ -0,0 +1,167 @@ +import time +import pickle +from tqdm import tqdm + +import torch +import torch.nn as nn + +import src.Log + 
+class Train_Bert: + def __init__(self, client_id, layer_id, channel, device): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.data_count = 0 + self.size = None + + def send_intermediate_output(self, output, labels, trace, cluster=None): + + forward_queue_name = f'intermediate_queue_{self.layer_id}_{cluster}' + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": trace} + ) + else: + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id]} + ) + if self.size is None: + self.size = len(message) + print(f'Length message: {self.size} (bytes).') + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + + message = pickle.dumps( + {"data": gradient.detach().cpu().numpy(), "trace": trace}) + + if self.size is None: + self.size = len(message) + print(f'Length message: {self.size} (bytes).') + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + def train_on_first_layer(self, model, learning, train_loader=None, cluster=0): + optimizer = torch.optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + 
self.channel.basic_qos(prefetch_count=1) + model = model.to(self.device) + + for batch in tqdm(train_loader, desc="Fine tuning"): + model.train() + optimizer.zero_grad() + + input_ids = batch['input_ids'].to(self.device) + labels = batch['labels'].to(self.device) + + intermediate_output = model(input_ids=input_ids) + intermediate_output = intermediate_output.detach().requires_grad_(True) + + self.data_count += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None, cluster=cluster) + + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + gradient_numpy = received_data["data"] + gradient = torch.tensor(gradient_numpy).to(self.device) + + output = model(input_ids=input_ids) + output.backward(gradient=gradient) + optimizer.step() + break + + else: + time.sleep(0.5) + + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "message": "Finish training!", "cluster": cluster} + + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + while True: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return True, self.data_count, received_data["send"] + time.sleep(0.5) + + def train_on_last_layer(self, model, learning, cluster=0): + optimizer = torch.optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + criterion = nn.CrossEntropyLoss() + result = True + + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{cluster}' + self.channel.queue_declare(queue=forward_queue_name, durable=False) + 
    def __init__(self, client_id, layer_id, channel, device):
        """Hold per-client training state for the KWT split-learning worker.

        Args:
            client_id: Unique id of this client (used in queue names such as
                reply_<client_id> and gradient_queue_*).
            layer_id: Which model partition this client trains.
            channel: Open RabbitMQ channel used for all queue traffic.
            device: Torch device the model runs on.
        """
        self.client_id = client_id
        self.layer_id = layer_id
        self.channel = channel
        self.device = device
        # Batches processed so far; returned to the server when training ends.
        self.data_count = 0
trace, cluster=None): + forward_queue_name = f'intermediate_queue_{self.layer_id}_{cluster}' + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels, "trace": trace} + ) + else: + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels, "trace": [self.client_id]} + ) + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + + message = pickle.dumps( + {"data": gradient.detach().cpu().numpy(), "trace": trace, "test": False}) + + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + def train_on_first_layer(self, model, learning, train_loader=None, cluster=None): + optimizer = optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + model.to(self.device) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + model.to(self.device) + + for batch in tqdm(train_loader, desc="Training"): + model.train() + optimizer.zero_grad() + training_data, labels = batch + training_data = training_data.to(self.device) + intermediate_output = 
model(training_data) + intermediate_output = intermediate_output.detach().requires_grad_(True) + + self.data_count += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None, cluster=cluster) + + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + gradient_numpy = received_data["data"] + gradient = torch.tensor(gradient_numpy).to(self.device) + + output = model(training_data) + output.backward(gradient=gradient) + optimizer.step() + break + + else: + time.sleep(0.5) + + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "message": "Finish training!", "cluster": cluster} + + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + broadcast_queue_name = f'reply_{self.client_id}' + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return True, self.data_count, received_data["send"] + time.sleep(0.5) + + def train_on_last_layer(self, model, learning, cluster): + optimizer = optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + result = True + + criterion = nn.CrossEntropyLoss() + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{cluster}' + + self.channel.queue_declare(queue=forward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + print('Waiting for intermediate output. 
To exit press CTRL+C') + model.to(self.device) + + while True: + model.train() + optimizer.zero_grad() + + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + trace = received_data["trace"] + labels = received_data["label"].to(self.device) + + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + + output = model(intermediate_output) + + loss = criterion(output, labels) + print(f"Loss: {loss.item()}") + if torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", "yellow") + result = False + + intermediate_output.retain_grad() + loss.backward() + optimizer.step() + self.data_count += 1 + gradient = intermediate_output.grad + self.send_gradient(gradient, trace) + else: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return result, self.data_count, received_data["send"] + time.sleep(0.5) diff --git a/other/2LS/src/train/VGG16.py b/other/2LS/src/train/VGG16.py new file mode 100644 index 0000000..731f1b6 --- /dev/null +++ b/other/2LS/src/train/VGG16.py @@ -0,0 +1,160 @@ +import time +import pickle +from tqdm import tqdm + +import torch +import torch.optim as optim +import torch.nn as nn + +import src.Log + +class Train_VGG16: + def __init__(self, client_id, layer_id, channel, device): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.data_count = 0 + + def send_intermediate_output(self, output, labels, trace, test=False, cluster=None): + forward_queue_name = 
    def train_on_first_layer(self, model, learning, train_loader=None, cluster=None):
        """Run the client-side half of split training for the VGG16 first layers.

        For every batch: forward through the local partition, publish the
        detached activations to the next layer's queue, then block until the
        matching gradient arrives on this client's backward queue and apply
        it.  After the loader is exhausted, notify the server and wait for a
        PAUSE broadcast, returning ``(True, data_count, send)``.

        :param model: torch.nn.Module holding the first model partition.
        :param learning: dict with "learning-rate" and "momentum" entries.
        :param train_loader: iterable of (image, label) batches.
        :param cluster: cluster id used to route the intermediate output.
        """
        optimizer = optim.SGD(model.parameters(), lr=learning["learning-rate"], momentum=learning["momentum"])

        backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}'
        self.channel.queue_declare(queue=backward_queue_name, durable=False)
        self.channel.basic_qos(prefetch_count=1)
        model.to(self.device)

        for batch in tqdm(train_loader, desc="Training"):
            model.train()
            optimizer.zero_grad()
            training_data, labels = batch
            training_data = training_data.to(self.device)
            intermediate_output = model(training_data)
            # Detach so the activation can be pickled; requires_grad_ marks it
            # as a differentiable input on the receiving side.
            intermediate_output = intermediate_output.detach().requires_grad_(True)

            self.data_count += 1

            self.send_intermediate_output(intermediate_output, labels, trace=None, test=False, cluster=cluster)

            # Block until the gradient for this batch comes back downstream.
            while True:
                method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True)
                if method_frame and body:
                    received_data = pickle.loads(body)
                    gradient_numpy = received_data["data"]
                    gradient = torch.tensor(gradient_numpy).to(self.device)

                    # Re-run the forward pass to rebuild the autograd graph
                    # (the first output was detached before sending).
                    output = model(training_data)
                    output.backward(gradient=gradient)
                    optimizer.step()
                    break

                else:
                    time.sleep(0.5)

        # NOTE(review): unlike the KWT/Bert trainers, notify_data here omits
        # the "cluster" key — confirm the server's NOTIFY handler tolerates
        # its absence, or this is a copy-paste omission.
        notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id,
                       "message": "Finish training!"}

        src.Log.print_with_color("[>>>] Finish training!", "red")
        self.send_to_server(notify_data)

        broadcast_queue_name = f'reply_{self.client_id}'
        while True:
            method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True)
            if body:
                received_data = pickle.loads(body)
                src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue")
                if received_data["action"] == "PAUSE":
                    return True , self.data_count, received_data["send"]
            time.sleep(0.5)
To exit press CTRL+C') + model.to(self.device) + + while True: + model.train() + optimizer.zero_grad() + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + trace = received_data["trace"] + labels = received_data["label"].to(self.device) + + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + + output = model(intermediate_output) + + loss = criterion(output, labels) + print(f"Loss: {loss.item()}") + if torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", "yellow") + result = False + + intermediate_output.retain_grad() + loss.backward() + optimizer.step() + self.data_count += 1 + gradient = intermediate_output.grad + self.send_gradient(gradient, trace) + + else: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return result, self.data_count, received_data["send"] + time.sleep(0.5) \ No newline at end of file diff --git a/other/2LS/src/train/__init__.py b/other/2LS/src/train/__init__.py new file mode 100644 index 0000000..72ac2aa --- /dev/null +++ b/other/2LS/src/train/__init__.py @@ -0,0 +1 @@ +from .VGG16 import * \ No newline at end of file diff --git a/other/2LS/src/val/Bert.py b/other/2LS/src/val/Bert.py new file mode 100644 index 0000000..7622d26 --- /dev/null +++ b/other/2LS/src/val/Bert.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from tqdm import tqdm + +from src.model.Bert_AGNEWS import Bert_AGNEWS +from src.dataset.dataloader import data_loader + +def val_Bert(data_name, state_dict_full, logger): + criterion = 
nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + test_loader = data_loader(data_name=data_name,train=False) + model = Bert_AGNEWS() + model = model.to(device) + model.load_state_dict(state_dict_full) + + model.eval() + correct, total, total_loss = 0, 0, 0 + + with torch.no_grad(): + for batch in tqdm(test_loader): + input_ids = batch['input_ids'].to(device) + labels = batch['labels'].to(device) + + logits = model(input_ids) + loss = criterion(logits, labels) + total_loss += loss.item() + correct += (logits.argmax(1) == labels).sum().item() + total += labels.size(0) + + acc = correct / total + avg_loss = total_loss / len(test_loader) + + print(f"Test Loss: {avg_loss:.2f}; Test Acc: {acc:.2f}") + + logger.log_info(f"Test Loss: {avg_loss:.2f}; Test Acc: {acc:.2f}") + + + + + + + + diff --git a/other/2LS/src/val/KWT.py b/other/2LS/src/val/KWT.py new file mode 100644 index 0000000..be0e004 --- /dev/null +++ b/other/2LS/src/val/KWT.py @@ -0,0 +1,39 @@ +from tqdm import tqdm +import torch +import torch.nn as nn + +from src.dataset.dataloader import data_loader +from src.model.KWT_SPEECHCOMMANDS import KWT_SPEECHCOMMANDS + +def val_KWT(data_name, state_dict_full, logger): + criterion = nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device) + test_loader = data_loader(data_name=data_name, train=False) + + model = KWT_SPEECHCOMMANDS() + model.load_state_dict(state_dict_full) + model.to(device) + model.eval() + + correct, total, total_loss = 0, 0, 0 + + with torch.no_grad(): + for mfcc, labels in tqdm(test_loader): + mfcc = mfcc.to(device) + labels = labels.to(device) + + outputs = model(mfcc) + loss = criterion(outputs, labels) + + total_loss += loss.item() + correct += (outputs.argmax(1) == labels).sum().item() + total += labels.size(0) + + acc = (correct / total) * 100 + avg_loss = total_loss / len(test_loader) + + print('Test set: Loss: {:.4f}; Accuracy: {}/{} 
({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) + logger.log_info('Test set: Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) \ No newline at end of file diff --git a/other/2LS/src/val/VGG16.py b/other/2LS/src/val/VGG16.py new file mode 100644 index 0000000..0d0b1dc --- /dev/null +++ b/other/2LS/src/val/VGG16.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + +from src.model.VGG16_CIFAR10 import VGG16_CIFAR10 +from tqdm import tqdm +from src.dataset.dataloader import data_loader + +def val_VGG16(data_name, state_dict_full, logger): + criterion = nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device) + test_loader = data_loader(data_name=data_name, train=False) + + model = VGG16_CIFAR10() + model = model.to(device) + model.load_state_dict(state_dict_full) + model.eval() + + correct, total, total_loss = 0, 0, 0 + + with torch.no_grad(): + for images, labels in tqdm(test_loader): + images = images.to(device) + labels = labels.to(device) + + outputs = model(images) + loss = criterion(outputs, labels) + + total_loss += loss.item() + correct += (outputs.argmax(1) == labels).sum().item() + total += labels.size(0) + + acc = (correct / total) * 100 + avg_loss = total_loss / len(test_loader) + + print('Test set:Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) + logger.log_info('Test set:Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) \ No newline at end of file diff --git a/other/2LS/src/val/__init__.py b/other/2LS/src/val/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/other/2LS/src/val/get_val.py b/other/2LS/src/val/get_val.py new file mode 100644 index 0000000..c0026a8 --- /dev/null +++ b/other/2LS/src/val/get_val.py @@ -0,0 +1,17 @@ +from src.val.VGG16 import val_VGG16 +from src.val.Bert import val_Bert +from src.val.KWT import val_KWT + +def get_val(model_name, data_name, 
state_dict_full, logger): + if model_name == 'Bert': + val_Bert(data_name, state_dict_full, logger) + return True + elif model_name == 'VGG16': + val_VGG16(data_name, state_dict_full, logger) + return True + elif model_name == 'KWT': + val_KWT(data_name, state_dict_full, logger) + return True + else: + return False +