diff --git a/other/2LS/README.md b/other/2LS/README.md new file mode 100644 index 0000000..e7bfdd2 --- /dev/null +++ b/other/2LS/README.md @@ -0,0 +1,33 @@ +# How to run 2LS + +## Server + +```commandline +python3 server.py +``` + +## Client +### Machine dai (layer-2 clients) +```commandline +python3 client.py --layer_id 2 --idx 0 --incluster 0 +python3 client.py --layer_id 2 --idx 1 --incluster 0 +python3 client.py --layer_id 2 --idx 2 --incluster 1 +``` +### Machines 12, 3, 8 (out-cluster 0) +```commandline +python3.8 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 0 +python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 0 +python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 0 +``` +### Machines 4, 5, 9 (out-cluster 1) +```commandline +python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 1 +python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 1 +python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 1 +``` +### Machines 6, 7, 10 (out-cluster 2) +```commandline +python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 2 +python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 2 +python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 2 +``` diff --git a/other/2LS/client.py b/other/2LS/client.py new file mode 100644 index 0000000..41d7943 --- /dev/null +++ b/other/2LS/client.py @@ -0,0 +1,58 @@ +import pika +import uuid +import argparse +import yaml + +import torch + +import src.Log +from src.RpcClient import RpcClient + +parser = argparse.ArgumentParser(description="Split learning framework") +parser.add_argument('--layer_id', type=int, required=True, help='ID of the layer, starting from 1') +parser.add_argument('--device', type=str, required=False, help='Device of the client') + +parser.add_argument('--idx', type=int, required=True, help='Index of the client') +parser.add_argument('--incluster', type=int, required=False, default=-1, help='In-cluster ID') +parser.add_argument('--outcluster', type=int, required=False, default=-1, help='Out-cluster ID') +args = parser.parse_args() + +with open('config.yaml', 'r') as file: + config = yaml.safe_load(file) + +client_id = uuid.uuid4() +address = config["rabbit"]["address"] +username = config["rabbit"]["username"] +password = config["rabbit"]["password"] +virtual_host = config["rabbit"]["virtual-host"] + +device = None +if args.device is None: + if torch.cuda.is_available(): + device = "cuda" + print(f"Using device: {torch.cuda.get_device_name(device)}") + else: + device = "cpu" + print("Using device: CPU") +else: + device = args.device + print(f"Using device: {device}") + +credentials = pika.PlainCredentials(username, password) +connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) +channel = connection.channel() + +in_cluster_id = args.incluster +out_cluster_id = args.outcluster +idx = args.idx + +if __name__ == "__main__": + src.Log.print_with_color("[>>>] Client sending registration message to server...", "red") + + data = {"action": "REGISTER", "client_id": client_id, "idx": idx, "layer_id": args.layer_id, + "in_cluster_id": in_cluster_id, "out_cluster_id": out_cluster_id, "message": "Hello from Client!"} + + client = RpcClient(client_id, args.layer_id, channel, device, in_cluster_id, idx) + client.send_to_server(data) + client.wait_response() + diff --git a/other/2LS/config.yaml b/other/2LS/config.yaml new file mode 100644 index 0000000..6ea0c43 --- /dev/null +++ b/other/2LS/config.yaml @@ -0,0 +1,34 @@ +name: Split Learning +server: + global-round: 1 + clients: + - 9 + - 3 + 
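# Note (added for clarity): the two "clients" entries above are per-layer totals, matching the README launch commands: 9 layer-1 training clients and 3 layer-2 clients. + # Each "info-cluster" row [2,1] below describes one out-cluster: 2 clients in in-cluster 0 and 1 client in in-cluster 1. + 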
num-cluster: 3 + cut-layer: 2 + info-cluster: + - [2,1] + - [2,1] + - [2,1] + model: BERT + data-name: AGNEWS + data-distribution: + non-iid: False + num-sample: 500 + num-label: 4 + random-seed: 1 + +rabbit: + address: 127.0.0.1 + username: admin + password: admin + virtual-host: / + +log_path: . +debug_mode: True + +learning: + learning-rate: 0.00005 + weight-decay: 0.01 + momentum: 0.5 + batch-size: 4 diff --git a/other/2LS/server.py b/other/2LS/server.py new file mode 100644 index 0000000..dadba5f --- /dev/null +++ b/other/2LS/server.py @@ -0,0 +1,30 @@ +import argparse +import sys +import signal +from src.Server import Server +from src.Utils import delete_old_queues +import yaml + +parser = argparse.ArgumentParser(description="Split learning framework with controller.") + +args = parser.parse_args() + +with open('config.yaml') as file: + config = yaml.safe_load(file) +address = config["rabbit"]["address"] +username = config["rabbit"]["username"] +password = config["rabbit"]["password"] +virtual_host = config["rabbit"]["virtual-host"] + + +def signal_handler(sig, frame): + print("\nCaught Ctrl+C. Stopping the program.") + delete_old_queues(address, username, password, virtual_host) + sys.exit(0) + + +if __name__ == "__main__": + signal.signal(signal.SIGINT, signal_handler) + delete_old_queues(address, username, password, virtual_host) + server = Server(config) + server.start() diff --git a/other/2LS/src/Log.py b/other/2LS/src/Log.py new file mode 100644 index 0000000..f485cbc --- /dev/null +++ b/other/2LS/src/Log.py @@ -0,0 +1,66 @@ +import logging +import os + + +class Colors: + COLORS = { + "header": '\033[95m', + "blue": '\033[94m', + "green": '\033[92m', + "yellow": '\033[93m', + "red": '\033[91m', + "end": '\033[0m' + } + + +class Logger: + def __init__(self, log_path, debug_mode=False, minimal=False): + # Set up a logger named "my_logger" + self.logger = logging.getLogger("my_logger") + self.logger.setLevel(logging.DEBUG) # Log level + self.debug_mode = debug_mode + self.minimal = minimal + + # Clear existing handlers to avoid duplicate logs if re-initialized + if self.logger.hasHandlers(): + self.logger.handlers.clear() + + if log_path: + # Create the log directory if it does not exist yet + log_dir = os.path.dirname(log_path) + if log_dir: + os.makedirs(log_dir, exist_ok=True) + + # Create a file handler that writes logs to the file + file_handler = logging.FileHandler(log_path) + file_handler.setLevel(logging.DEBUG) + + # Log format + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + + # Attach the file handler to the logger + self.logger.addHandler(file_handler) + + def log_info(self, message): + if not self.minimal: + print(f"[INFO] {message}") + self.logger.info(message) + + def log_warning(self, message): + print_with_color(f"[WARN] {message}", "yellow") + self.logger.warning(message) + + def log_error(self, message): + print_with_color(f"[ERROR] {message}", "red") + self.logger.error(message) + + def log_debug(self, message): + if self.debug_mode: + print_with_color(f"[DEBUG] {message}", "green") + self.logger.debug(message) + + +def print_with_color(text, color): + color_code = Colors.COLORS.get(color.lower(), Colors.COLORS["end"]) + print(f"{color_code}{text}{Colors.COLORS['end']}") diff --git a/other/2LS/src/RpcClient.py b/other/2LS/src/RpcClient.py new file mode 100644 index 0000000..60357e3 --- /dev/null +++ b/other/2LS/src/RpcClient.py @@ -0,0 +1,138 @@ +import time +import pickle +import copy + +import src.Log + +from 
src.model.BERT_AGNEWS import BERT_AGNEWS +from src.model.VGG16_CIFAR10 import VGG16_CIFAR10 +from src.model.KWT_SPEECHCOMMANDS import KWT_SPEECHCOMMANDS +from src.train.VGG16 import Train_VGG16 +from src.train.BERT import Train_BERT +from src.train.KWT import Train_KWT +from src.dataset.dataloader import data_loader + +from peft import LoraConfig, get_peft_model + + +class RpcClient: + def __init__(self, client_id, layer_id, channel, device, in_cluster_id, idx): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.in_cluster_id = in_cluster_id + self.idx = idx + + self.response = None + self.model = None + self.model_train = None + self.train_loader = None + self.label_count = None + self.peft_config = None + + self.train_set = None + self.label_to_indices = None + + def wait_response(self): + status = True + reply_queue_name = f'reply_{self.client_id}' + self.channel.queue_declare(reply_queue_name, durable=False) + while status: + method_frame, header_frame, body = self.channel.basic_get(queue=reply_queue_name, auto_ack=True) + if body: + status = self.response_message(body) + time.sleep(0.5) + + def response_message(self, body): + self.response = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Client received: {self.response.get('message', 'No message')}", "blue") + action = self.response["action"] + + if action == "START": + state_dict = self.response["parameters"] + model_name = self.response["model_name"] + cut_layer = self.response['cut_layer'] + label_count = self.response['label_count'] + data_name = self.response["data_name"] + + if self.label_count is None: + self.label_count = label_count + + if self.label_count is not None: + src.Log.print_with_color(f"Label distribution of client: {self.label_count}", "yellow") + + if model_name == 'VGG16': + self.model_train = Train_VGG16(self.client_id, self.layer_id, self.channel, self.device, self.in_cluster_id, self.idx) + elif model_name == 'BERT': + self.model_train = Train_BERT(self.client_id, self.layer_id, self.channel, self.device, self.in_cluster_id, self.idx) + self.peft_config = LoraConfig( + task_type="SEQ_CLS", + r=8, lora_alpha=16, lora_dropout=0.1, + bias="none", + target_modules=["query", "key", "value", "dense"] + ) + else: + self.model_train = Train_KWT(self.client_id, self.layer_id, self.channel, self.device, self.in_cluster_id, self.idx) + + if self.model is None: + if model_name == 'VGG16': + klass = VGG16_CIFAR10 + elif model_name == 'KWT': + klass = KWT_SPEECHCOMMANDS + else: + klass = BERT_AGNEWS + + if self.layer_id == 1: + self.model = klass(end_layer=cut_layer) + else: + self.model = klass(start_layer=cut_layer) + + if state_dict: + self.model.load_state_dict(state_dict) + + learning = self.response["learning"] + batch_size = learning["batch-size"] + + if model_name == 'BERT': + self.model = get_peft_model(self.model, self.peft_config) + if self.layer_id == 2: + for param in self.model.layer15.classifier.parameters(): + param.requires_grad = True + + self.model.to(self.device) + + # Start training + if self.layer_id == 1: + if self.train_loader is None: + self.train_loader = data_loader(data_name, batch_size, self.label_count, train=True) + + result, size = self.model_train.train_on_first_layer(self.model, learning, self.train_loader) + + else: + result, size = self.model_train.train_on_last_layer(self.model, learning) + + if model_name == 'BERT': + self.model = self.model.merge_and_unload() + + model_state_dict = 
copy.deepcopy(self.model.state_dict()) + if self.device != "cpu": + for key in model_state_dict: + model_state_dict[key] = model_state_dict[key].to('cpu') + + data = {"action": "UPDATE", "client_id": self.client_id, "layer_id": self.layer_id, + "result": result, "size": size, "in_cluster_id": self.in_cluster_id, + "message": "Sent parameters to Server", "parameters": model_state_dict} + src.Log.print_with_color("[>>>] Client sent parameters to server", "red") + self.send_to_server(data) + return True + elif action == "STOP": + return False + return True + + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) diff --git a/other/2LS/src/Server.py b/other/2LS/src/Server.py new file mode 100644 index 0000000..8e11102 --- /dev/null +++ b/other/2LS/src/Server.py @@ -0,0 +1,326 @@ +import torch +import os +import random +import pika +import pickle +import sys +import numpy as np +import copy +import src.Log +import src.Utils +from src.val.get_val import get_val +from src.model.BERT_AGNEWS import BERT_AGNEWS +from src.model.KWT_SPEECHCOMMANDS import KWT_SPEECHCOMMANDS +from src.model.VGG16_CIFAR10 import VGG16_CIFAR10 + +class Server: + def __init__(self, config): + # RabbitMQ + address = config["rabbit"]["address"] + username = config["rabbit"]["username"] + password = config["rabbit"]["password"] + virtual_host = config["rabbit"]["virtual-host"] + + self.model_name = config["server"]["model"] + self.data_name = config["server"]["data-name"] + self.total_clients = config["server"]["clients"] + self.num_cluster = config["server"]["num-cluster"] + self.cut_layer = config["server"]["cut-layer"] + self.info_cluster = config["server"]["info-cluster"] + self.global_round = config["server"]["global-round"] + self.round = self.global_round + self.global_model = None + + # Clients + self.learning = config["learning"] + self.data_distribution = config["server"]["data-distribution"] + + # Data distribution + self.non_iid = self.data_distribution["non-iid"] + self.num_label = self.data_distribution["num-label"] + self.num_sample = self.data_distribution["num-sample"] + self.random_seed = config["server"]["random-seed"] + self.label_counts = None + + if self.random_seed: + random.seed(self.random_seed) + + log_path = config["log_path"] + + credentials = pika.PlainCredentials(username, password) + self.connection = pika.BlockingConnection( + pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) + self.channel = self.connection.channel() + self.channel.queue_declare(queue='rpc_queue') + + self.out_cluster_ids = list(range(self.num_cluster)) + self.count_notify = [] # list + self.count_update = [] # list + self.check_in_cluster = [] + self.in_params = None # list([[],[]]) + self.in_sizes = None # list([[],[]]) + self.register_clients = [0 for _ in range(len(self.total_clients))] + self.responses = {} # Save response + self.list_clients = [] + self.round_result = True + + self.current_out_cluster = 0 + self.full_state_dict = None + + self.channel.basic_qos(prefetch_count=1) + self.reply_channel = self.connection.channel() + self.channel.basic_consume(queue='rpc_queue', on_message_callback=self.on_request) + + debug_mode = config["debug_mode"] + + self.logger = src.Log.Logger(f"{log_path}/app.log", debug_mode) + + self.logger.log_info(f"Application start. 
Server is waiting for {self.total_clients} clients.") + + def distribution(self): + if self.non_iid: + if self.data_name == "SPEECHCOMMANDS": + label_distribution = np.array([[0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.25, 0.25, 0.25], + ]) + else: + label_distribution = np.array([[0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.0, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.3, 0.1], + ]) + + self.label_counts = (label_distribution * self.num_sample).astype(int) + else: + self.label_counts = np.full((self.total_clients[0], self.num_label), self.num_sample // self.num_label) + + def on_request(self, ch, method, props, body): + message = pickle.loads(body) + routing_key = props.reply_to + action = message["action"] + client_id = message["client_id"] + layer_id = message["layer_id"] + + self.responses[routing_key] = message + + if action == "REGISTER": + in_cluster_id = message["in_cluster_id"] + out_cluster_id = message["out_cluster_id"] + idx = message["idx"] + + if (client_id, layer_id, in_cluster_id, out_cluster_id, idx) not in self.list_clients: + self.list_clients.append((client_id, layer_id, in_cluster_id, out_cluster_id, idx)) + + src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue") + self.register_clients[layer_id - 1] += 1 + + if self.register_clients == self.total_clients: + self.distribution() + self.set_up() + + self.logger.log_info(f"Start training round {self.global_round - self.round + 1}") + src.Log.print_with_color(f"Start training round {self.global_round - self.round + 1}", "yellow") + + self.count_notify = copy.deepcopy(self.info_cluster[0]) + self.count_update = [x * 2 for x in self.info_cluster[0]] + self.in_params = [[[],[]] for _ in range(len(self.info_cluster[0]))] + self.in_sizes = [[[],[]] for _ in range(len(self.info_cluster[0]))] + + src.Log.print_with_color(f"List out-cluster : {self.out_cluster_ids}", "yellow") + self.current_out_cluster = self.out_cluster_ids.pop(0) + src.Log.print_with_color(f"Start out-cluster {self.current_out_cluster}", "yellow") + self.notify_clients() + + elif action == "NOTIFY": + src.Log.print_with_color(f"[<<<] Received message from client: {message}", "blue") + in_cluster = message["in_cluster_id"] + + message = {"action": "PAUSE", + "message": "Pause training and please send your parameters"} + + self.count_notify[in_cluster] -= 1 + if self.count_notify[in_cluster] == 0: + src.Log.print_with_color(f"Received finish training notification clients from in cluster {in_cluster}.", "yellow") + for (client_id, layer_id, in_cluster_id , out_cluster_id, _, _) in self.list_clients: + if (out_cluster_id 
== self.current_out_cluster or layer_id == 2) and (in_cluster_id == in_cluster): + self.send_to_response(client_id, pickle.dumps(message)) + + elif action == "UPDATE": + data_message = message["message"] + result = message["result"] + model_state_dict = message["parameters"] + client_size = message["size"] + in_cluster = message["in_cluster_id"] + + if not result: + self.round_result = False + src.Log.print_with_color(f"[<<<] Received message from {client_id}: {data_message}", "blue") + + self.count_update[in_cluster] -= 1 + self.in_params[in_cluster][layer_id - 1].append(model_state_dict) + self.in_sizes[in_cluster][layer_id - 1].append(client_size) + + if self.count_update[in_cluster] == 0: + self.check_in_cluster.append(in_cluster) + + if len(self.check_in_cluster) == len(self.info_cluster[self.current_out_cluster]): + avg_in_cluster = self.avg_in_clusters() + + for num, check in enumerate(self.check_in_cluster): + alpha = float(1 / (1 + num)) + self.global_model = self.fed_async_aggregate(self.global_model, avg_in_cluster[check], alpha) + torch.save(self.global_model, f'{self.model_name}_{self.data_name}.pth') + + if len(self.out_cluster_ids) == 0: + if self.round_result: + # Test + if not get_val(self.model_name, self.data_name, self.global_model, self.logger): + self.logger.log_warning("Training failed!") + src.Log.print_with_color("Training failed!", "yellow") + self.round = 0 + else: + self.round -= 1 + else: + self.round = 0 + + if self.round > 0: + self.logger.log_info(f"Start training round {self.global_round - self.round + 1}") + + self.out_cluster_ids = list(range(self.num_cluster)) + random.shuffle(self.out_cluster_ids) + self.global_model = None + self.reset() + + src.Log.print_with_color(f"List out-cluster : {self.out_cluster_ids}", "yellow") + self.current_out_cluster = self.out_cluster_ids.pop(0) + self.notify_clients(True, self.current_out_cluster) + src.Log.print_with_color(f"Start out-cluster {self.current_out_cluster}", "yellow") + else: + self.logger.log_info("Stop training !!!") + self.notify_clients(start=False) + sys.exit() + else: + self.reset() + + self.current_out_cluster = self.out_cluster_ids.pop(0) + src.Log.print_with_color(f"Start out-cluster {self.current_out_cluster}", "yellow") + self.notify_clients(True, self.current_out_cluster) + + + ch.basic_ack(delivery_tag=method.delivery_tag) + + def fed_async_aggregate(self, out_cluster_sd, in_cluster_sd, alpha=1.0): + if out_cluster_sd is None: + out_cluster_sd = in_cluster_sd + src.Log.print_with_color(f">>> FedAsync Out-cluster {self.current_out_cluster} updated (alpha={alpha}).", "green") + else: + for key in in_cluster_sd.keys(): + out_cluster_sd[key] = (1.0 - alpha) * out_cluster_sd[key].float() + alpha * in_cluster_sd[key].float() + out_cluster_sd[key] = out_cluster_sd[key].to(in_cluster_sd[key].dtype) + src.Log.print_with_color(f">>> FedAsync Out-cluster {self.current_out_cluster} updated (alpha={alpha}).", "green") + return out_cluster_sd + + def notify_clients(self, start=True, out_id=0): + filepath = f'{self.model_name}_{self.data_name}.pth' + + for (client_id, layer_id, _, out_cluster_id, _, labels) in self.list_clients: + state_dict = None + if start: + if out_cluster_id == out_id or out_cluster_id == -1: + if os.path.exists(filepath): + self.full_state_dict = torch.load(filepath, weights_only=True) + + if self.model_name == 'VGG16': + klass = VGG16_CIFAR10 + elif self.model_name == "KWT": + klass = KWT_SPEECHCOMMANDS + else: + klass = BERT_AGNEWS + + if layer_id == 1: + model = 
klass(end_layer=self.cut_layer) + else: + model = klass(start_layer=self.cut_layer) + state_dict = model.state_dict() + keys = state_dict.keys() + + for key in keys: + state_dict[key] = self.full_state_dict[key] + + src.Log.print_with_color(f"Load model successfully", "green") + else: + src.Log.print_with_color(f"File {filepath} does not exist.", "yellow") + + src.Log.print_with_color(f"[>>>] Sent start training request to client {client_id}", "red") + + response = {"action": "START", + "message": "Server accept the connection!", + "parameters": state_dict, + "cut_layer": self.cut_layer, + "model_name": self.model_name, + "data_name": self.data_name, + "learning": self.learning, + "label_count": labels} + self.send_to_response(client_id, pickle.dumps(response)) + else: + continue + else: + src.Log.print_with_color(f"[>>>] Sent stop training request to client {client_id}", "red") + response = {"action": "STOP", + "message": "Stop training!"} + self.send_to_response(client_id, pickle.dumps(response)) + + def set_up(self): + self.label_counts = self.label_counts.tolist() + new_list_client = [] + for (client_id, layer_id, in_cluster_id, out_cluster_id, idx) in self.list_clients: + if layer_id == 1: + new_list_client.append((client_id, layer_id, in_cluster_id, out_cluster_id, idx, self.label_counts.pop(0))) + else: + new_list_client.append((client_id, layer_id, in_cluster_id, out_cluster_id, idx, [])) + + self.list_clients = new_list_client + + def start(self): + self.channel.start_consuming() + + def send_to_response(self, client_id, message): + reply_queue_name = f'reply_{client_id}' + self.reply_channel.queue_declare(reply_queue_name, durable=False) + src.Log.print_with_color(f"[>>>] Sent notification to client {client_id}", "red") + self.reply_channel.basic_publish(exchange='', routing_key=reply_queue_name, body=message) + + def avg_in_clusters(self): + avg_in_cluster = [] + + for i in range(len(self.in_params)): + list_params = self.in_params[i] + list_sizes = self.in_sizes[i] + full_dict = {} + + for idx, layer_dict in enumerate(list_params): + sd = src.Utils.fedavg_state_dicts(layer_dict, list_sizes[idx]) + full_dict.update(copy.deepcopy(sd)) + + avg_in_cluster.append(full_dict) + + return avg_in_cluster + + def reset(self): + self.check_in_cluster = [] + self.count_notify = copy.deepcopy(self.info_cluster[0]) + self.count_update = [x * 2 for x in self.info_cluster[0]] + self.in_params = [[[], []] for _ in range(len(self.info_cluster[0]))] + self.in_sizes = [[[], []] for _ in range(len(self.info_cluster[0]))] diff --git a/other/2LS/src/Utils.py b/other/2LS/src/Utils.py new file mode 100644 index 0000000..dc42a72 --- /dev/null +++ b/other/2LS/src/Utils.py @@ -0,0 +1,79 @@ +import numpy as np +import random +import pika +import torch + +from requests.auth import HTTPBasicAuth +import requests + + +def delete_old_queues(address, username, password, virtual_host): + url = f'http://{address}:15672/api/queues' + response = requests.get(url, auth=HTTPBasicAuth(username, password)) + + if response.status_code == 200: + queues = response.json() + + credentials = pika.PlainCredentials(username, password) + connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) + http_channel = connection.channel() + + for queue in queues: + queue_name = queue['name'] + if queue_name.startswith("reply") or queue_name.startswith("intermediate_queue") or queue_name.startswith( + "gradient_queue") or queue_name.startswith("rpc_queue"): + + 
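# Queues owned by this framework are deleted outright; any other queue is only purged below. + 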
http_channel.queue_delete(queue=queue_name) + + else: + http_channel.queue_purge(queue=queue_name) + + connection.close() + return True + else: + return False + +def fedavg_state_dicts(state_dicts, weights=None): + """ + Average (FedAvg) a list of state_dicts. + - state_dicts: list of dicts {param_name: tensor} + - weights: list of corresponding weights (None means every model gets weight 1) + Returns a dict {param_name: tensor_avg} + """ + num = len(state_dicts) + if num == 0: + raise ValueError("fedavg_state_dicts: no state_dicts to average.") + + if weights is None: + weights = [1.0] * num + total_w = sum(weights) + + # Collect all keys + all_keys = set().union(*(sd.keys() for sd in state_dicts)) + avg_dict = {} + + for key in all_keys: + # Accumulate weight-scaled tensors, replacing NaNs + acc = None + for sd, w in zip(state_dicts, weights): + if key not in sd: + continue + t = sd[key].float() + if torch.isnan(t).any(): + t = torch.nan_to_num(t) # zero-fill + t = t * w + acc = t if acc is None else acc + t + + # Divide by the total weight + avg = acc / total_w + + # Cast back to the original dtype + orig = next(sd[key] for sd in state_dicts if key in sd) + if orig.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.bool): + avg = avg.round().to(orig.dtype) + else: + avg = avg.to(orig.dtype) + + avg_dict[key] = avg + + return avg_dict diff --git a/other/2LS/src/Validation.py b/other/2LS/src/Validation.py new file mode 100644 index 0000000..6e9feb5 --- /dev/null +++ b/other/2LS/src/Validation.py @@ -0,0 +1,72 @@ +import numpy as np +import math +from tqdm import tqdm +import torch +import torchvision +import torchvision.transforms as transforms +import torch.nn.functional as F + +from src.dataset.dataloader import data_loader +from src.model import * + +def test(model_name, data_name, state_dict_full, logger): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Use centralized data_loader + try: + test_loader = data_loader(data_name=data_name, train=False) + except ValueError as e: + logger.log_error(str(e)) + return False + + # Dynamic model class lookup + if model_name in ['Bert', 'BERT']: + if data_name == 'AGNEWS': + klass = BERT_AGNEWS + else: + klass = BERT_EMOTION + elif model_name == 'KWT': + klass = KWT_SPEECHCOMMANDS + else: + # globals() contains the model classes + # because of 'from src.model import *' above + klass_name = f"{model_name}_{data_name}" + klass = globals().get(klass_name) + + if klass is None: + logger.log_error(f"Model class for {model_name} and {data_name} not found.") + return False + + # Initialize the full model + model = klass() + + model.load_state_dict(state_dict_full) + model = model.to(device) + + # evaluation mode + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in tqdm(test_loader): + data = data.to(device) + target = target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() + + test_loss /= len(test_loader.dataset) + accuracy = 100.0 * correct / len(test_loader.dataset) + + report = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), accuracy) + + print(report) + + if math.isnan(test_loss) or abs(test_loss) > 1e6: + return False + else: + logger.log_info(report) + + return True 
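Below is a minimal sketch of how `fedavg_state_dicts` from `src/Utils.py` behaves, assuming the package layout above so that `src.Utils` is importable; the names `sd_a` and `sd_b` are illustrative only:

```python
import torch

from src.Utils import fedavg_state_dicts

# Two toy "client" state dicts that share one parameter name.
sd_a = {"fc.weight": torch.ones(2, 2)}
sd_b = {"fc.weight": torch.full((2, 2), 3.0)}

# Weight client A twice as heavily as client B:
# every element becomes (2*1 + 1*3) / (2 + 1) = 5/3.
avg = fedavg_state_dicts([sd_a, sd_b], weights=[2.0, 1.0])
print(avg["fc.weight"])  # tensor filled with ~1.6667
```

This mirrors how `Server.avg_in_clusters` uses it: the per-client dataset sizes reported in each UPDATE message are passed as `weights`.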
diff --git a/other/2LS/src/dataset/AGNEWS.py b/other/2LS/src/dataset/AGNEWS.py new file mode 100644 index 0000000..fabd126 --- /dev/null +++ b/other/2LS/src/dataset/AGNEWS.py @@ -0,0 +1,30 @@ +import torch + +class AGNEWS_DATASET(torch.utils.data.Dataset): + def __init__(self, texts, labels, tokenizer, max_length=128): + self.texts = texts + self.labels = labels + self.tokenizer = tokenizer + self.max_length = max_length + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = str(self.texts[idx]) + label = self.labels[idx] + + # Tokenize + encoding = self.tokenizer( + text, + truncation=True, + padding='max_length', + max_length=self.max_length, + return_tensors='pt' + ) + + return { + 'input_ids': encoding['input_ids'].flatten(), + 'attention_mask': encoding['attention_mask'].flatten(), + 'labels': torch.tensor(label, dtype=torch.long) + } diff --git a/other/2LS/src/dataset/SPEECHCOMMANDS.py b/other/2LS/src/dataset/SPEECHCOMMANDS.py new file mode 100644 index 0000000..b40ea08 --- /dev/null +++ b/other/2LS/src/dataset/SPEECHCOMMANDS.py @@ -0,0 +1,222 @@ +import os +import random +import torch +import numpy as np +import urllib.request +import tarfile +import warnings +from scipy.io.wavfile import WavFileWarning +warnings.filterwarnings("ignore", category=WavFileWarning) +from torch.utils.data import Dataset +from scipy.io import wavfile +from scipy.fftpack import dct + +CLASSES = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 'silence', 'unknown'] + +def compute_mfcc(waveform, sample_rate=16000, n_mfcc=40, n_fft=480, hop_length=160, n_mels=40): + """Compute MFCC features using scipy (no torchaudio backend needed)""" + emphasized = np.append(waveform[0], waveform[1:] - 0.97 * waveform[:-1]) + + frame_length = n_fft + num_frames = 1 + (len(emphasized) - frame_length) // hop_length + frames = np.zeros((num_frames, frame_length)) + for i in range(num_frames): + frames[i] = emphasized[i * hop_length: i * hop_length + frame_length] + + frames *= np.hamming(frame_length) + + mag_frames = np.absolute(np.fft.rfft(frames, n_fft)) + pow_frames = (1.0 / n_fft) * (mag_frames ** 2) + + low_freq_mel = 0 + high_freq_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700) + mel_points = np.linspace(low_freq_mel, high_freq_mel, n_mels + 2) + hz_points = 700 * (10 ** (mel_points / 2595) - 1) + bin_points = np.floor((n_fft + 1) * hz_points / sample_rate).astype(int) + + fbank = np.zeros((n_mels, int(np.floor(n_fft / 2 + 1)))) + for m in range(1, n_mels + 1): + f_m_minus = bin_points[m - 1] + f_m = bin_points[m] + f_m_plus = bin_points[m + 1] + for k in range(f_m_minus, f_m): + fbank[m - 1, k] = (k - f_m_minus) / (f_m - f_m_minus) + for k in range(f_m, f_m_plus): + fbank[m - 1, k] = (f_m_plus - k) / (f_m_plus - f_m) + + filter_banks = np.dot(pow_frames, fbank.T) + filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks) + filter_banks = 20 * np.log10(filter_banks) + + mfcc = dct(filter_banks, type=2, axis=1, norm='ortho')[:, :n_mfcc] + + return mfcc.T + + +def _download_and_extract(root_dir): + download_dir = os.path.dirname(root_dir) # ./data/SpeechCommands + os.makedirs(download_dir, exist_ok=True) + + tar_path = os.path.join(download_dir, "speech_commands_v0.02.tar.gz") + url = "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz" + + if not os.path.exists(root_dir): + print(f"Dataset not found at {root_dir}. 
Downloading from {url}...") + try: + from tqdm import tqdm + + def progress_callback(count, block_size, total_size): + if not hasattr(progress_callback, 't'): + progress_callback.t = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading SpeechCommands") + progress_callback.t.update(block_size) + if count * block_size >= total_size: + progress_callback.t.close() + + urllib.request.urlretrieve(url, tar_path, reporthook=progress_callback) + print(f"Downloaded to {tar_path}. Extracting...") + with tarfile.open(tar_path, 'r:gz') as tar: + tar.extractall(path=root_dir) + print(f"Extracted to {root_dir}") + except Exception as e: + raise RuntimeError(f"Failed to download/extract dataset: {e}") + +def _load_background_noise(root): + """Load all background noise files from _background_noise_ directory""" + noise_dir = os.path.join(root, '_background_noise_') + noise_data = [] + if not os.path.exists(noise_dir): + print(f"[WARNING] Background noise dir not found: {noise_dir}") + return noise_data + + for f in os.listdir(noise_dir): + if not f.endswith('.wav'): + continue + fpath = os.path.join(noise_dir, f) + try: + sr, wav = wavfile.read(fpath) + if wav.dtype == np.int16: + wav = wav.astype(np.float32) / 32768.0 + elif wav.dtype == np.int32: + wav = wav.astype(np.float32) / 2147483648.0 + else: + wav = wav.astype(np.float32) + noise_data.append(wav) + except Exception as e: + print(f"[WARNING] Error reading noise file {fpath}: {e}") + + print(f"Loaded {len(noise_data)} background noise files") + return noise_data + +class SpeechCommandsDataset(Dataset): + def __init__(self, root='./data', subset='training', n_silence=2300): + self.root = os.path.join(root, 'SpeechCommands', 'speech_commands_v0.02') + self.subset = subset + self.samples = [] + self.noise_data = [] + + if not os.path.exists(self.root): + _download_and_extract(self.root) + + self.noise_data = _load_background_noise(self.root) + + val_list = set() + test_list = set() + + val_file = os.path.join(self.root, 'validation_list.txt') + test_file = os.path.join(self.root, 'testing_list.txt') + + if os.path.exists(val_file): + with open(val_file) as f: + val_list = set(line.strip() for line in f) + if os.path.exists(test_file): + with open(test_file) as f: + test_list = set(line.strip() for line in f) + + for label_dir in os.listdir(self.root): + label_path = os.path.join(self.root, label_dir) + if not os.path.isdir(label_path) or label_dir.startswith('_'): + continue + + for audio_file in os.listdir(label_path): + if not audio_file.endswith('.wav'): + continue + + rel_path = os.path.join(label_dir, audio_file) + + if subset == 'validation' and rel_path in val_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'testing' and rel_path in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + elif subset == 'training' and rel_path not in val_list and rel_path not in test_list: + self.samples.append((os.path.join(label_path, audio_file), label_dir)) + + if self.noise_data: + if subset == 'training': + num_silence = n_silence + else: + num_silence = max(1, n_silence // 9) + + for _ in range(num_silence): + self.samples.append((None, 'silence')) + + print(f"Added {num_silence} silence samples from background noise") + + # Pre-calculate labels for RpcClient efficiency + self.labels = [] + for _, label in self.samples: + if label in CLASSES: + self.labels.append(CLASSES.index(label)) + else: + self.labels.append(CLASSES.index('unknown')) + + print(f"Total 
{len(self.samples)} samples for {subset}") + + def _get_silence_waveform(self): + target_length = 16000 + noise = random.choice(self.noise_data) + + if len(noise) <= target_length: + waveform = np.pad(noise, (0, max(0, target_length - len(noise)))) + else: + start = random.randint(0, len(noise) - target_length) + waveform = noise[start: start + target_length] + + return waveform + + def __getitem__(self, idx): + audio_path, label = self.samples[idx] + + try: + if audio_path is None: + waveform = self._get_silence_waveform() + else: + sample_rate, waveform = wavfile.read(audio_path) + if waveform.dtype == np.int16: + waveform = waveform.astype(np.float32) / 32768.0 + elif waveform.dtype == np.int32: + waveform = waveform.astype(np.float32) / 2147483648.0 + else: + waveform = waveform.astype(np.float32) + + target_length = 16000 + if len(waveform) < target_length: + waveform = np.pad(waveform, (0, target_length - len(waveform))) + else: + waveform = waveform[:target_length] + + mfcc = compute_mfcc(waveform, sample_rate=16000, n_mfcc=40) + mfcc = torch.tensor(mfcc, dtype=torch.float32) + + except Exception as e: + print(f"[WARNING] Error processing sample {idx}: {e}, using zeros") + mfcc = torch.zeros(40, 98, dtype=torch.float32) + + if label in CLASSES: + label_idx = CLASSES.index(label) + else: + label_idx = CLASSES.index('unknown') + + return mfcc, label_idx + + def __len__(self): + return len(self.samples) \ No newline at end of file diff --git a/other/2LS/src/dataset/__init__.py b/other/2LS/src/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/other/2LS/src/dataset/dataloader.py b/other/2LS/src/dataset/dataloader.py new file mode 100644 index 0000000..6518f7c --- /dev/null +++ b/other/2LS/src/dataset/dataloader.py @@ -0,0 +1,171 @@ +import random + +import torch +import torchvision +from collections import defaultdict +from tqdm import tqdm + +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from transformers import BertTokenizer +from datasets import load_dataset + +from src.dataset.AGNEWS import AGNEWS_DATASET +from src.dataset.SPEECHCOMMANDS import SpeechCommandsDataset + +def AGNEWS(batch_size=None, distribution=None, train=True): + cache_dir = './hf_cache' + print(f"Loading AGNEWS dataset with cache_dir={cache_dir}...") + dataset = load_dataset( + 'ag_news', + download_mode='reuse_dataset_if_exists', + cache_dir=cache_dir + ) + print("Dataset loaded successfully.") + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + if train: + train_data = dataset['train'] + train_target_counts = {k: v for k, v in enumerate(distribution)} + train_by_class = defaultdict(list) + for text, label in zip(train_data['text'], train_data['label']): + train_by_class[label].append((text, label)) + + train_texts, train_labels = [], [] + for label, count in train_target_counts.items(): + samples = random.sample(train_by_class[label], count) + train_texts.extend([t for t, _ in samples]) + train_labels.extend([l for _, l in samples]) + print("Train samples:", len(train_texts), {l: train_labels.count(l) for l in set(train_labels)}) + + train_set = AGNEWS_DATASET(train_texts, train_labels, tokenizer, max_length=128) + train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) + return train_loader + else: + test_data = dataset['test'] + distribution = [500, 500, 500, 500] + test_target_counts = {k: v for k, v in enumerate(distribution)} + test_by_class = defaultdict(list) + for text, label in zip(test_data['text'], 
test_data['label']): + test_by_class[label].append((text, label)) + + test_texts, test_labels = [], [] + for label, count in test_target_counts.items(): + samples = random.sample(test_by_class[label], count) + test_texts.extend([t for t, _ in samples]) + test_labels.extend([l for _, l in samples]) + + print("Test samples:", len(test_texts), {l: test_labels.count(l) for l in set(test_labels)}) + + test_set = AGNEWS_DATASET(test_texts, test_labels, tokenizer, max_length=128) + test_loader = DataLoader(test_set, batch_size=100, shuffle=False) + return test_loader + +def CIFAR10(batch_size=None, distribution=None, train=True): + if train: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + + train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + + label_to_indices = defaultdict(list) + for idx, (_, label) in tqdm(enumerate(train_set)): + label_to_indices[int(label)].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + selected_indices.extend(random.sample(label_to_indices[label], count)) + subset = torch.utils.data.Subset(train_set, selected_indices) + + train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=1) + return test_loader + +def MNIST(batch_size=None, distribution=None, train=True): + if train: + transform_train = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train) + + label_to_indices = defaultdict(list) + for idx, (_, label) in tqdm(enumerate(train_set)): + label_to_indices[int(label)].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + selected_indices.extend(random.sample(label_to_indices[label], count)) + subset = torch.utils.data.Subset(train_set, selected_indices) + + train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test) + test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=1) + return test_loader + +def SPEECHCOMMANDS(batch_size=None, distribution=None, train=True): + """Google Speech Commands V2 dataset loader for KWT model""" + if train: + dataset = SpeechCommandsDataset(root='./data', subset='training') + + if distribution is not None: + # Build label index from samples list directly (avoid reading wav files) + from src.dataset.SPEECHCOMMANDS import CLASSES + label_to_indices = defaultdict(list) + for idx, (audio_path, label_name) in enumerate(dataset.samples): + if label_name in CLASSES: + label_idx = CLASSES.index(label_name) + else: + label_idx = CLASSES.index('unknown') + 
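# Group sample indices by class so the per-label sampling below can honor the requested distribution. + 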
label_to_indices[label_idx].append(idx) + + selected_indices = [] + for label, count in enumerate(distribution): + if count > 0 and label in label_to_indices: + available = label_to_indices[label] + selected_indices.extend(random.sample(available, min(count, len(available)))) + + print(f"[DEBUG] Selected {len(selected_indices)} samples after distribution filter") + subset = torch.utils.data.Subset(dataset, selected_indices) + train_loader = DataLoader(subset, batch_size=batch_size, shuffle=True) + else: + train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + return train_loader + else: + dataset = SpeechCommandsDataset(root='./data', subset='testing') + test_loader = DataLoader(dataset, batch_size=20, shuffle=False) + return test_loader + +def data_loader(data_name=None, batch_size=None, distribution=None, train=True): + if data_name == 'AGNEWS': + data = AGNEWS(batch_size, distribution, train) + elif data_name == 'SPEECHCOMMANDS': + data = SPEECHCOMMANDS(batch_size, distribution, train) + elif data_name == 'MNIST': + data = MNIST(batch_size, distribution, train) + elif data_name == 'CIFAR10': + data = CIFAR10(batch_size, distribution, train) + else: + raise ValueError(f"Dataset {data_name} not supported.") + + return data \ No newline at end of file diff --git a/other/2LS/src/model/BERT_AGNEWS.py b/other/2LS/src/model/BERT_AGNEWS.py new file mode 100644 index 0000000..6819fa4 --- /dev/null +++ b/other/2LS/src/model/BERT_AGNEWS.py @@ -0,0 +1,219 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class DotDict(dict): + def __getattr__(self, k): + try: return self[k] + except KeyError: raise AttributeError(k) + def __setattr__(self, k, v): self[k] = v + def __delattr__(self, k): del self[k] + +# AG_NEWS has 4 classes: World, Sports, Business, Sci/Tech +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + +class BertSdpaSelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, dropout_prob): + super(BertSdpaSelfAttention, self).__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + 
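# This dropout is applied to the attention probabilities in forward(), not to the projected outputs. + 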
self.dropout = nn.Dropout(dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + import math + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSdpaSelfAttention(hidden_size, num_attention_heads, dropout_prob) + self.output = BertSelfOutput(hidden_size, dropout_prob) + + def forward(self, hidden_states): + self_output = self.self(hidden_states) + attention_output = self.output(self_output, hidden_states) + return attention_output + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = nn.GELU() + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, intermediate_size, dropout_prob): + super(BertLayer, self).__init__() + self.attention = BertAttention(hidden_size, num_attention_heads, dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size) + self.output = BertOutput(hidden_size, intermediate_size, dropout_prob) + + def forward(self, hidden_states): + attention_output = self.attention(hidden_states) + intermediate_output = 
self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class BertClassifier(nn.Module): + def __init__(self, hidden_size, num_labels, dropout_prob=0.1): + super(BertClassifier, self).__init__() + self.dropout = nn.Dropout(dropout_prob) + self.classifier = nn.Linear(hidden_size, num_labels) + + def forward(self, pooled_output): + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +class BERT_AGNEWS(nn.Module): + def __init__(self, vocab_size=28996, hidden_size=768, num_attention_heads=12, intermediate_size=3072, + max_position_embeddings=512, type_vocab_size=2, dropout_prob=0.1, n_block=12, + start_layer=0, end_layer=15): + super(BERT_AGNEWS, self).__init__() + self.start_layer = start_layer + self.end_layer = 15 if end_layer == -1 else end_layer + self.config = DotDict( + model_type="bert", + vocab_size=vocab_size, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + bos_token_id=101, eos_token_id=102, pad_token_id=0, + is_encoder_decoder=False, tie_word_embeddings=False, + use_return_dict=True, output_attentions=False, output_hidden_states=False + ) + + if self.start_layer < 1 <= self.end_layer: + self.layer1 = BertEmbeddings(vocab_size=vocab_size, hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, dropout_prob=dropout_prob) + + for i in range(12): + layer_idx = 2 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + BertLayer(hidden_size, num_attention_heads, intermediate_size, dropout_prob)) + + if self.start_layer < 14 <= self.end_layer: + self.layer14 = BertPooler(hidden_size) + + if self.start_layer < 15 <= self.end_layer: + self.layer15 = BertClassifier(hidden_size, 4) + + def forward(self, input_ids=None, token_type_ids=None, **kwargs): + x = input_ids + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x, token_type_ids) + + for i in range(12): + layer_idx = 2 + i + if self.start_layer < layer_idx <= self.end_layer: + layer_module = getattr(self, f'layer{layer_idx}') + x = layer_module(x) + + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + + return x diff --git a/other/2LS/src/model/KWT_SPEECHCOMMANDS.py b/other/2LS/src/model/KWT_SPEECHCOMMANDS.py new file mode 100644 index 0000000..648dd50 --- /dev/null +++ b/other/2LS/src/model/KWT_SPEECHCOMMANDS.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=1, mlp_dim=256): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + 
x_attn + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class KWT_SPEECHCOMMANDS(nn.Module): + def __init__(self, start_layer=0, end_layer=17): + super().__init__() + self.start_layer = start_layer + self.end_layer = 17 if end_layer == -1 else end_layer + + n_mfcc = 40 + time_steps = 98 + embed_dim = 64 + num_heads = 1 + mlp_dim = 256 + num_classes = 12 + dropout = 0.1 + + if self.start_layer < 1 <= self.end_layer: + self.layer1 = nn.Linear(n_mfcc, embed_dim) + + if self.start_layer < 2 <= self.end_layer: + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + if self.start_layer < 3 <= self.end_layer: + self.pos_embed = nn.Parameter(torch.randn(1, time_steps + 1, embed_dim)) + self.dropout = nn.Dropout(dropout) + + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + setattr(self, f'layer{layer_idx}', + TransformerEncoderBlock(embed_dim, num_heads, mlp_dim)) + + if self.start_layer < 16 <= self.end_layer: + self.layer16 = nn.LayerNorm(embed_dim) + + if self.start_layer < 17 <= self.end_layer: + self.layer17 = nn.Linear(embed_dim, num_classes) + + self._init_weights() + + def _init_weights(self): + if hasattr(self, 'cls_token'): + nn.init.trunc_normal_(self.cls_token, std=0.02) + if hasattr(self, 'pos_embed'): + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = x.transpose(1, 2) + x = self.layer1(x) + + if self.start_layer < 2 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) + + if self.start_layer < 3 <= self.end_layer: + x = x + self.pos_embed + x = self.dropout(x) + + for i in range(12): + layer_idx = 4 + i + if self.start_layer < layer_idx <= self.end_layer: + x = getattr(self, f'layer{layer_idx}')(x) + + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x[:, 0]) + + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + + return x \ No newline at end of file diff --git a/other/2LS/src/model/MobileNetv1_CIFAR10.py b/other/2LS/src/model/MobileNetv1_CIFAR10.py new file mode 100644 index 0000000..e6f783a --- /dev/null +++ b/other/2LS/src/model/MobileNetv1_CIFAR10.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn + +class MobileNetv1_CIFAR10(nn.Module): + def __init__(self, start_layer=0, end_layer=84): + super(MobileNetv1_CIFAR10, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + if start_layer < 1 <= end_layer: + self.layer1 = nn.Conv2d(3, 32, 3, 1, 1) + if start_layer < 2 <= end_layer: + self.layer2 = nn.BatchNorm2d(32) + if start_layer < 3 <= end_layer: + self.layer3 = nn.ReLU() + if start_layer < 4 <= end_layer: + self.layer4 = nn.Conv2d(32, 32, 3, 1, 1) + if start_layer < 5 <= end_layer: + self.layer5 = nn.BatchNorm2d(32) + if start_layer < 6 <= end_layer: + self.layer6 = nn.ReLU() + if start_layer < 7 <= end_layer: + self.layer7 = nn.Conv2d(32, 64, 1) + if start_layer < 8 <= end_layer: + self.layer8 = nn.BatchNorm2d(64) + if start_layer < 9 <= end_layer: + self.layer9 = nn.ReLU() + if start_layer < 10 <= end_layer: + self.layer10 = nn.Conv2d(64, 64, 3, 2, 1) + if start_layer < 11 <= end_layer: + self.layer11 = nn.BatchNorm2d(64) + if start_layer < 12 <= end_layer: + self.layer12 = nn.ReLU() + if start_layer < 13 <= end_layer: + self.layer13 = nn.Conv2d(64, 128, 1) + if start_layer < 14 <= end_layer: + self.layer14 = nn.BatchNorm2d(128) + if start_layer < 15 <= end_layer: + self.layer15 = nn.ReLU() + if start_layer 
< 16 <= end_layer: + self.layer16 = nn.Conv2d(128, 128, 3, 1, 1) + if start_layer < 17 <= end_layer: + self.layer17 = nn.BatchNorm2d(128) + if start_layer < 18 <= end_layer: + self.layer18 = nn.ReLU() + if start_layer < 19 <= end_layer: + self.layer19 = nn.Conv2d(128, 128, 1) + if start_layer < 20 <= end_layer: + self.layer20 = nn.BatchNorm2d(128) + if start_layer < 21 <= end_layer: + self.layer21 = nn.ReLU() + if start_layer < 22 <= end_layer: + self.layer22 = nn.Conv2d(128, 128, 3, 2, 1) + if start_layer < 23 <= end_layer: + self.layer23 = nn.BatchNorm2d(128) + if start_layer < 24 <= end_layer: + self.layer24 = nn.ReLU() + if start_layer < 25 <= end_layer: + self.layer25 = nn.Conv2d(128, 256, 1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(256) + if start_layer < 27 <= end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(256, 256, 3, 1, 1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(256) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(256, 256, 1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(256) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.Conv2d(256, 256, 3, 2, 1) + if start_layer < 35 <= end_layer: + self.layer35 = nn.BatchNorm2d(256) + if start_layer < 36 <= end_layer: + self.layer36 = nn.ReLU() + if start_layer < 37 <= end_layer: + self.layer37 = nn.Conv2d(256, 512, 1) + if start_layer < 38 <= end_layer: + self.layer38 = nn.BatchNorm2d(512) + if start_layer < 39 <= end_layer: + self.layer39 = nn.ReLU() + if start_layer < 40 <= end_layer: + self.layer40 = nn.Conv2d(512, 512, 3, 1, 1) + if start_layer < 41 <= end_layer: + self.layer41 = nn.BatchNorm2d(512) + if start_layer < 42 <= end_layer: + self.layer42 = nn.ReLU() + if start_layer < 43 <= end_layer: + self.layer43 = nn.Conv2d(512, 512, 1) + if start_layer < 44 <= end_layer: + self.layer44 = nn.BatchNorm2d(512) + if start_layer < 45 <= end_layer: + self.layer45 = nn.ReLU() + if start_layer < 46 <= end_layer: + self.layer46 = nn.Conv2d(512, 512, 3, 1, 1) + if start_layer < 47 <= end_layer: + self.layer47 = nn.BatchNorm2d(512) + if start_layer < 48 <= end_layer: + self.layer48 = nn.ReLU() + if start_layer < 49 <= end_layer: + self.layer49 = nn.Conv2d(512, 512, 1) + if start_layer < 50 <= end_layer: + self.layer50 = nn.BatchNorm2d(512) + if start_layer < 51 <= end_layer: + self.layer51 = nn.ReLU() + if start_layer < 52 <= end_layer: + self.layer52 = nn.Conv2d(512, 512, 3, 1, 1) + if start_layer < 53 <= end_layer: + self.layer53 = nn.BatchNorm2d(512) + if start_layer < 54 <= end_layer: + self.layer54 = nn.ReLU() + if start_layer < 55 <= end_layer: + self.layer55 = nn.Conv2d(512, 512, 1) + if start_layer < 56 <= end_layer: + self.layer56 = nn.BatchNorm2d(512) + if start_layer < 57 <= end_layer: + self.layer57 = nn.ReLU() + if start_layer < 58 <= end_layer: + self.layer58 = nn.Conv2d(512, 512, 3, 1, 1) + if start_layer < 59 <= end_layer: + self.layer59 = nn.BatchNorm2d(512) + if start_layer < 60 <= end_layer: + self.layer60 = nn.ReLU() + if start_layer < 61 <= end_layer: + self.layer61 = nn.Conv2d(512, 512, 1) + if start_layer < 62 <= end_layer: + self.layer62 = nn.BatchNorm2d(512) + if start_layer < 63 <= end_layer: + self.layer63 = nn.ReLU() + if start_layer < 64 <= end_layer: + self.layer64 = nn.Conv2d(512, 512, 3, 1, 1) + if start_layer < 65 <= end_layer: + 
self.layer65 = nn.BatchNorm2d(512) + if start_layer < 66 <= end_layer: + self.layer66 = nn.ReLU() + if start_layer < 67 <= end_layer: + self.layer67 = nn.Conv2d(512, 512, 1) + if start_layer < 68 <= end_layer: + self.layer68 = nn.BatchNorm2d(512) + if start_layer < 69 <= end_layer: + self.layer69 = nn.ReLU() + if start_layer < 70 <= end_layer: + self.layer70 = nn.Conv2d(512, 512, 3, 2, 1) + if start_layer < 71 <= end_layer: + self.layer71 = nn.BatchNorm2d(512) + if start_layer < 72 <= end_layer: + self.layer72 = nn.ReLU() + if start_layer < 73 <= end_layer: + self.layer73 = nn.Conv2d(512, 1024, 1) + if start_layer < 74 <= end_layer: + self.layer74 = nn.BatchNorm2d(1024) + if start_layer < 75 <= end_layer: + self.layer75 = nn.ReLU() + if start_layer < 76 <= end_layer: + self.layer76 = nn.Conv2d(1024, 1024, 3, 1, 1) + if start_layer < 77 <= end_layer: + self.layer77 = nn.BatchNorm2d(1024) + if start_layer < 78 <= end_layer: + self.layer78 = nn.ReLU() + if start_layer < 79 <= end_layer: + self.layer79 = nn.Conv2d(1024, 1024, 1) + if start_layer < 80 <= end_layer: + self.layer80 = nn.BatchNorm2d(1024) + if start_layer < 81 <= end_layer: + self.layer81 = nn.ReLU() + if start_layer < 82 <= end_layer: + self.layer82 = nn.MaxPool2d(2, 2) + if start_layer < 83 <= end_layer: + self.layer83 = nn.Flatten(1, -1) + if start_layer < 84 <= end_layer: + self.layer84 = nn.Linear(1024, 10) + + def forward(self, x): + for i in range(1, 85): + if self.start_layer < i <= self.end_layer: + layer = getattr(self, f'layer{i}', None) + if layer is not None: + x = layer(x) + return x diff --git a/other/2LS/src/model/MobileNetv1_MNIST.py b/other/2LS/src/model/MobileNetv1_MNIST.py new file mode 100644 index 0000000..880a6a3 --- /dev/null +++ b/other/2LS/src/model/MobileNetv1_MNIST.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn + +class MobileNetv1_MNIST(nn.Module): + def __init__(self, start_layer=0, end_layer=84): + super(MobileNetv1_MNIST, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + if start_layer < 1 <= end_layer: + self.layer1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) + if start_layer < 2 <= end_layer: + self.layer2 = nn.BatchNorm2d(32) + if start_layer < 3 <= end_layer: + self.layer3 = nn.ReLU() + if start_layer < 4 <= end_layer: + self.layer4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) + if start_layer < 5 <= end_layer: + self.layer5 = nn.BatchNorm2d(32) + if start_layer < 6 <= end_layer: + self.layer6 = nn.ReLU() + if start_layer < 7 <= end_layer: + self.layer7 = nn.Conv2d(32, 64, kernel_size=1, stride=1) + if start_layer < 8 <= end_layer: + self.layer8 = nn.BatchNorm2d(64) + if start_layer < 9 <= end_layer: + self.layer9 = nn.ReLU() + if start_layer < 10 <= end_layer: + self.layer10 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1) + if start_layer < 11 <= end_layer: + self.layer11 = nn.BatchNorm2d(64) + if start_layer < 12 <= end_layer: + self.layer12 = nn.ReLU() + if start_layer < 13 <= end_layer: + self.layer13 = nn.Conv2d(64, 128, kernel_size=1, stride=1) + if start_layer < 14 <= end_layer: + self.layer14 = nn.BatchNorm2d(128) + if start_layer < 15 <= end_layer: + self.layer15 = nn.ReLU() + if start_layer < 16 <= end_layer: + self.layer16 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 17 <= end_layer: + self.layer17 = nn.BatchNorm2d(128) + if start_layer < 18 <= end_layer: + self.layer18 = nn.ReLU() + if start_layer < 19 <= end_layer: + self.layer19 = nn.Conv2d(128, 128, kernel_size=1, stride=1) + 
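# NOTE: each numbered layer is instantiated only when its index falls inside
+        # this partition's (start_layer, end_layer] range, so a split-learning
+        # client allocates weights solely for its own shard of the network.
+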
if start_layer < 20 <= end_layer: + self.layer20 = nn.BatchNorm2d(128) + if start_layer < 21 <= end_layer: + self.layer21 = nn.ReLU() + if start_layer < 22 <= end_layer: + self.layer22 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1) + if start_layer < 23 <= end_layer: + self.layer23 = nn.BatchNorm2d(128) + if start_layer < 24 <= end_layer: + self.layer24 = nn.ReLU() + if start_layer < 25 <= end_layer: + self.layer25 = nn.Conv2d(128, 256, kernel_size=1, stride=1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(256) + if start_layer < 27 <= end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(256) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(256, 256, kernel_size=1, stride=1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(256) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1) + if start_layer < 35 <= end_layer: + self.layer35 = nn.BatchNorm2d(256) + if start_layer < 36 <= end_layer: + self.layer36 = nn.ReLU() + if start_layer < 37 <= end_layer: + self.layer37 = nn.Conv2d(256, 512, kernel_size=1, stride=1) + if start_layer < 38 <= end_layer: + self.layer38 = nn.BatchNorm2d(512) + if start_layer < 39 <= end_layer: + self.layer39 = nn.ReLU() + if start_layer < 40 <= end_layer: + self.layer40 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 41 <= end_layer: + self.layer41 = nn.BatchNorm2d(512) + if start_layer < 42 <= end_layer: + self.layer42 = nn.ReLU() + if start_layer < 43 <= end_layer: + self.layer43 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 44 <= end_layer: + self.layer44 = nn.BatchNorm2d(512) + if start_layer < 45 <= end_layer: + self.layer45 = nn.ReLU() + if start_layer < 46 <= end_layer: + self.layer46 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 47 <= end_layer: + self.layer47 = nn.BatchNorm2d(512) + if start_layer < 48 <= end_layer: + self.layer48 = nn.ReLU() + if start_layer < 49 <= end_layer: + self.layer49 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 50 <= end_layer: + self.layer50 = nn.BatchNorm2d(512) + if start_layer < 51 <= end_layer: + self.layer51 = nn.ReLU() + if start_layer < 52 <= end_layer: + self.layer52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 53 <= end_layer: + self.layer53 = nn.BatchNorm2d(512) + if start_layer < 54 <= end_layer: + self.layer54 = nn.ReLU() + if start_layer < 55 <= end_layer: + self.layer55 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 56 <= end_layer: + self.layer56 = nn.BatchNorm2d(512) + if start_layer < 57 <= end_layer: + self.layer57 = nn.ReLU() + if start_layer < 58 <= end_layer: + self.layer58 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 59 <= end_layer: + self.layer59 = nn.BatchNorm2d(512) + if start_layer < 60 <= end_layer: + self.layer60 = nn.ReLU() + if start_layer < 61 <= end_layer: + self.layer61 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 62 <= end_layer: + self.layer62 = nn.BatchNorm2d(512) + if start_layer < 63 <= end_layer: + self.layer63 = nn.ReLU() + if start_layer < 64 <= end_layer: + self.layer64 = nn.Conv2d(512, 
512, kernel_size=3, stride=1, padding=1) + if start_layer < 65 <= end_layer: + self.layer65 = nn.BatchNorm2d(512) + if start_layer < 66 <= end_layer: + self.layer66 = nn.ReLU() + if start_layer < 67 <= end_layer: + self.layer67 = nn.Conv2d(512, 512, kernel_size=1, stride=1) + if start_layer < 68 <= end_layer: + self.layer68 = nn.BatchNorm2d(512) + if start_layer < 69 <= end_layer: + self.layer69 = nn.ReLU() + if start_layer < 70 <= end_layer: + self.layer70 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1) + if start_layer < 71 <= end_layer: + self.layer71 = nn.BatchNorm2d(512) + if start_layer < 72 <= end_layer: + self.layer72 = nn.ReLU() + if start_layer < 73 <= end_layer: + self.layer73 = nn.Conv2d(512, 1024, kernel_size=1, stride=1) + if start_layer < 74 <= end_layer: + self.layer74 = nn.BatchNorm2d(1024) + if start_layer < 75 <= end_layer: + self.layer75 = nn.ReLU() + if start_layer < 76 <= end_layer: + self.layer76 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1) + if start_layer < 77 <= end_layer: + self.layer77 = nn.BatchNorm2d(1024) + if start_layer < 78 <= end_layer: + self.layer78 = nn.ReLU() + if start_layer < 79 <= end_layer: + self.layer79 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1) + if start_layer < 80 <= end_layer: + self.layer80 = nn.BatchNorm2d(1024) + if start_layer < 81 <= end_layer: + self.layer81 = nn.ReLU() + if start_layer < 82 <= end_layer: + self.layer82 = nn.MaxPool2d(2, 2) + if start_layer < 83 <= end_layer: + self.layer83 = nn.Flatten(1, -1) + if start_layer < 84 <= end_layer: + self.layer84 = nn.Linear(1024, 10) + + def forward(self, x): + for i in range(1, 85): + if self.start_layer < i <= self.end_layer: + layer = getattr(self, f'layer{i}', None) + if layer is not None: + x = layer(x) + return x diff --git a/other/2LS/src/model/VGG16_CIFAR10.py b/other/2LS/src/model/VGG16_CIFAR10.py new file mode 100644 index 0000000..e84652c --- /dev/null +++ b/other/2LS/src/model/VGG16_CIFAR10.py @@ -0,0 +1,230 @@ +import torch.nn as nn + +class VGG16_CIFAR10(nn.Module): + def __init__(self, start_layer=0, end_layer=52): + super(VGG16_CIFAR10, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + if start_layer < 1 <= end_layer: + self.layer1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 2 <= end_layer: + self.layer2 = nn.BatchNorm2d(64) + if start_layer < 3 <= end_layer: + self.layer3 = nn.ReLU() + if start_layer < 4 <= end_layer: + self.layer4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 5 <= end_layer: + self.layer5 = nn.BatchNorm2d(64) + if start_layer < 6 <= end_layer: + self.layer6 = nn.ReLU() + if start_layer < 7 <= end_layer: + self.layer7 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 8 <= end_layer: + self.layer8 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 9 <= end_layer: + self.layer9 = nn.BatchNorm2d(128) + if start_layer < 10 <= end_layer: + self.layer10 = nn.ReLU() + if start_layer < 11 <= end_layer: + self.layer11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 12 <= end_layer: + self.layer12 = nn.BatchNorm2d(128) + if start_layer < 13 <= end_layer: + self.layer13 = nn.ReLU() + if start_layer < 14 <= end_layer: + self.layer14 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 15 <= end_layer: + self.layer15 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 16 <= end_layer: + self.layer16 = nn.BatchNorm2d(256) + if start_layer < 17 <= 
end_layer: + self.layer17 = nn.ReLU() + if start_layer < 18 <= end_layer: + self.layer18 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 19 <= end_layer: + self.layer19 = nn.BatchNorm2d(256) + if start_layer < 20 <= end_layer: + self.layer20 = nn.ReLU() + if start_layer < 21 <= end_layer: + self.layer21 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 22 <= end_layer: + self.layer22 = nn.BatchNorm2d(256) + if start_layer < 23 <= end_layer: + self.layer23 = nn.ReLU() + if start_layer < 24 <= end_layer: + self.layer24 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 25 <= end_layer: + self.layer25 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(512) + if start_layer < 27 <= end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(512) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(512) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 35 <= end_layer: + self.layer35 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 36 <= end_layer: + self.layer36 = nn.BatchNorm2d(512) + if start_layer < 37 <= end_layer: + self.layer37 = nn.ReLU() + if start_layer < 38 <= end_layer: + self.layer38 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 39 <= end_layer: + self.layer39 = nn.BatchNorm2d(512) + if start_layer < 40 <= end_layer: + self.layer40 = nn.ReLU() + if start_layer < 41 <= end_layer: + self.layer41 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 42 <= end_layer: + self.layer42 = nn.BatchNorm2d(512) + if start_layer < 43 <= end_layer: + self.layer43 = nn.ReLU() + if start_layer < 44 <= end_layer: + self.layer44 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 45 <= end_layer: + self.layer45 = nn.Flatten(1, -1) + if start_layer < 46 <= end_layer: + self.layer46 = nn.Dropout(0.5) + if start_layer < 47 <= end_layer: + self.layer47 = nn.Linear(512 * 1 * 1, 4096) + if start_layer < 48 <= end_layer: + self.layer48 = nn.ReLU() + if start_layer < 49 <= end_layer: + self.layer49 = nn.Dropout(0.5) + if start_layer < 50 <= end_layer: + self.layer50 = nn.Linear(4096, 4096) + if start_layer < 51 <= end_layer: + self.layer51 = nn.ReLU() + if start_layer < 52 <= end_layer: + self.layer52 = nn.Linear(4096, 10) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + if self.start_layer < 3 <= self.end_layer: + x = self.layer3(x) + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x) + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) + if self.start_layer 
< 11 <= self.end_layer: + x = self.layer11(x) + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + if self.start_layer < 13 <= self.end_layer: + x = self.layer13(x) + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x) + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + if self.start_layer < 18 <= self.end_layer: + x = self.layer18(x) + if self.start_layer < 19 <= self.end_layer: + x = self.layer19(x) + if self.start_layer < 20 <= self.end_layer: + x = self.layer20(x) + if self.start_layer < 21 <= self.end_layer: + x = self.layer21(x) + if self.start_layer < 22 <= self.end_layer: + x = self.layer22(x) + if self.start_layer < 23 <= self.end_layer: + x = self.layer23(x) + if self.start_layer < 24 <= self.end_layer: + x = self.layer24(x) + + if self.start_layer < 25 <= self.end_layer: + x = self.layer25(x) + if self.start_layer < 26 <= self.end_layer: + x = self.layer26(x) + if self.start_layer < 27 <= self.end_layer: + x = self.layer27(x) + if self.start_layer < 28 <= self.end_layer: + x = self.layer28(x) + if self.start_layer < 29 <= self.end_layer: + x = self.layer29(x) + if self.start_layer < 30 <= self.end_layer: + x = self.layer30(x) + if self.start_layer < 31 <= self.end_layer: + x = self.layer31(x) + if self.start_layer < 32 <= self.end_layer: + x = self.layer32(x) + if self.start_layer < 33 <= self.end_layer: + x = self.layer33(x) + if self.start_layer < 34 <= self.end_layer: + x = self.layer34(x) + + if self.start_layer < 35 <= self.end_layer: + x = self.layer35(x) + if self.start_layer < 36 <= self.end_layer: + x = self.layer36(x) + if self.start_layer < 37 <= self.end_layer: + x = self.layer37(x) + if self.start_layer < 38 <= self.end_layer: + x = self.layer38(x) + if self.start_layer < 39 <= self.end_layer: + x = self.layer39(x) + if self.start_layer < 40 <= self.end_layer: + x = self.layer40(x) + if self.start_layer < 41 <= self.end_layer: + x = self.layer41(x) + if self.start_layer < 42 <= self.end_layer: + x = self.layer42(x) + if self.start_layer < 43 <= self.end_layer: + x = self.layer43(x) + if self.start_layer < 44 <= self.end_layer: + x = self.layer44(x) + + if self.start_layer < 45 <= self.end_layer: + x = self.layer45(x) + if self.start_layer < 46 <= self.end_layer: + x = self.layer46(x) + if self.start_layer < 47 <= self.end_layer: + x = self.layer47(x) + if self.start_layer < 48 <= self.end_layer: + x = self.layer48(x) + if self.start_layer < 49 <= self.end_layer: + x = self.layer49(x) + if self.start_layer < 50 <= self.end_layer: + x = self.layer50(x) + if self.start_layer < 51 <= self.end_layer: + x = self.layer51(x) + if self.start_layer < 52 <= self.end_layer: + x = self.layer52(x) + + return x diff --git a/other/2LS/src/model/VGG16_MNIST.py b/other/2LS/src/model/VGG16_MNIST.py new file mode 100644 index 0000000..1cc070e --- /dev/null +++ b/other/2LS/src/model/VGG16_MNIST.py @@ -0,0 +1,226 @@ +import torch.nn as nn + +class VGG16_MNIST(nn.Module): + def __init__(self, start_layer=0, end_layer=51): + super(VGG16_MNIST, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + if start_layer < 1 <= end_layer: + self.layer1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 2 <= end_layer: + self.layer2 = nn.BatchNorm2d(64) + if start_layer < 3 <= end_layer: + self.layer3 = nn.ReLU() + if start_layer < 4 <= end_layer: + self.layer4 = 
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + if start_layer < 5 <= end_layer: + self.layer5 = nn.BatchNorm2d(64) + if start_layer < 6 <= end_layer: + self.layer6 = nn.ReLU() + if start_layer < 7 <= end_layer: + self.layer7 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 8 <= end_layer: + self.layer8 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 9 <= end_layer: + self.layer9 = nn.BatchNorm2d(128) + if start_layer < 10 <= end_layer: + self.layer10 = nn.ReLU() + if start_layer < 11 <= end_layer: + self.layer11 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + if start_layer < 12 <= end_layer: + self.layer12 = nn.BatchNorm2d(128) + if start_layer < 13 <= end_layer: + self.layer13 = nn.ReLU() + if start_layer < 14 <= end_layer: + self.layer14 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 15 <= end_layer: + self.layer15 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 16 <= end_layer: + self.layer16 = nn.BatchNorm2d(256) + if start_layer < 17 <= end_layer: + self.layer17 = nn.ReLU() + if start_layer < 18 <= end_layer: + self.layer18 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 19 <= end_layer: + self.layer19 = nn.BatchNorm2d(256) + if start_layer < 20 <= end_layer: + self.layer20 = nn.ReLU() + if start_layer < 21 <= end_layer: + self.layer21 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + if start_layer < 22 <= end_layer: + self.layer22 = nn.BatchNorm2d(256) + if start_layer < 23 <= end_layer: + self.layer23 = nn.ReLU() + if start_layer < 24 <= end_layer: + self.layer24 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 25 <= end_layer: + self.layer25 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 26 <= end_layer: + self.layer26 = nn.BatchNorm2d(512) + if start_layer < 27 <= end_layer: + self.layer27 = nn.ReLU() + if start_layer < 28 <= end_layer: + self.layer28 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 29 <= end_layer: + self.layer29 = nn.BatchNorm2d(512) + if start_layer < 30 <= end_layer: + self.layer30 = nn.ReLU() + if start_layer < 31 <= end_layer: + self.layer31 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 32 <= end_layer: + self.layer32 = nn.BatchNorm2d(512) + if start_layer < 33 <= end_layer: + self.layer33 = nn.ReLU() + if start_layer < 34 <= end_layer: + self.layer34 = nn.MaxPool2d(kernel_size=2, stride=2) + + if start_layer < 35 <= end_layer: + self.layer35 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 36 <= end_layer: + self.layer36 = nn.BatchNorm2d(512) + if start_layer < 37 <= end_layer: + self.layer37 = nn.ReLU() + if start_layer < 38 <= end_layer: + self.layer38 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 39 <= end_layer: + self.layer39 = nn.BatchNorm2d(512) + if start_layer < 40 <= end_layer: + self.layer40 = nn.ReLU() + if start_layer < 41 <= end_layer: + self.layer41 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + if start_layer < 42 <= end_layer: + self.layer42 = nn.BatchNorm2d(512) + if start_layer < 43 <= end_layer: + self.layer43 = nn.ReLU() + + if start_layer < 44 <= end_layer: + self.layer44 = nn.Flatten(1, -1) + if start_layer < 45 <= end_layer: + self.layer45 = nn.Dropout(0.5) + if start_layer < 46 <= end_layer: + self.layer46 = nn.Linear(512, 4096) + if start_layer < 47 <= end_layer: + self.layer47 = nn.ReLU() + if start_layer < 48 <= 
end_layer: + self.layer48 = nn.Dropout(0.5) + if start_layer < 49 <= end_layer: + self.layer49 = nn.Linear(4096, 4096) + if start_layer < 50 <= end_layer: + self.layer50 = nn.ReLU() + if start_layer < 51 <= end_layer: + self.layer51 = nn.Linear(4096, 10) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + if self.start_layer < 3 <= self.end_layer: + x = self.layer3(x) + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x) + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x) + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + if self.start_layer < 13 <= self.end_layer: + x = self.layer13(x) + if self.start_layer < 14 <= self.end_layer: + x = self.layer14(x) + + if self.start_layer < 15 <= self.end_layer: + x = self.layer15(x) + if self.start_layer < 16 <= self.end_layer: + x = self.layer16(x) + if self.start_layer < 17 <= self.end_layer: + x = self.layer17(x) + if self.start_layer < 18 <= self.end_layer: + x = self.layer18(x) + if self.start_layer < 19 <= self.end_layer: + x = self.layer19(x) + if self.start_layer < 20 <= self.end_layer: + x = self.layer20(x) + if self.start_layer < 21 <= self.end_layer: + x = self.layer21(x) + if self.start_layer < 22 <= self.end_layer: + x = self.layer22(x) + if self.start_layer < 23 <= self.end_layer: + x = self.layer23(x) + if self.start_layer < 24 <= self.end_layer: + x = self.layer24(x) + + if self.start_layer < 25 <= self.end_layer: + x = self.layer25(x) + if self.start_layer < 26 <= self.end_layer: + x = self.layer26(x) + if self.start_layer < 27 <= self.end_layer: + x = self.layer27(x) + if self.start_layer < 28 <= self.end_layer: + x = self.layer28(x) + if self.start_layer < 29 <= self.end_layer: + x = self.layer29(x) + if self.start_layer < 30 <= self.end_layer: + x = self.layer30(x) + if self.start_layer < 31 <= self.end_layer: + x = self.layer31(x) + if self.start_layer < 32 <= self.end_layer: + x = self.layer32(x) + if self.start_layer < 33 <= self.end_layer: + x = self.layer33(x) + if self.start_layer < 34 <= self.end_layer: + x = self.layer34(x) + + if self.start_layer < 35 <= self.end_layer: + x = self.layer35(x) + if self.start_layer < 36 <= self.end_layer: + x = self.layer36(x) + if self.start_layer < 37 <= self.end_layer: + x = self.layer37(x) + if self.start_layer < 38 <= self.end_layer: + x = self.layer38(x) + if self.start_layer < 39 <= self.end_layer: + x = self.layer39(x) + if self.start_layer < 40 <= self.end_layer: + x = self.layer40(x) + if self.start_layer < 41 <= self.end_layer: + x = self.layer41(x) + if self.start_layer < 42 <= self.end_layer: + x = self.layer42(x) + if self.start_layer < 43 <= self.end_layer: + x = self.layer43(x) + + if self.start_layer < 44 <= self.end_layer: + x = self.layer44(x) + if self.start_layer < 45 <= self.end_layer: + x = self.layer45(x) + if self.start_layer < 46 <= self.end_layer: + x = self.layer46(x) + if self.start_layer < 47 <= self.end_layer: + x = self.layer47(x) + if self.start_layer < 48 <= self.end_layer: + x = self.layer48(x) + if 
self.start_layer < 49 <= self.end_layer: + x = self.layer49(x) + if self.start_layer < 50 <= self.end_layer: + x = self.layer50(x) + if self.start_layer < 51 <= self.end_layer: + x = self.layer51(x) + + return x diff --git a/other/2LS/src/model/ViT_CIFAR10.py b/other/2LS/src/model/ViT_CIFAR10.py new file mode 100644 index 0000000..098e206 --- /dev/null +++ b/other/2LS/src/model/ViT_CIFAR10.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=4, mlp_dim=256): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + # Attention + residual + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + x_attn + + # MLP + residual + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class ViT_CIFAR10(nn.Module): + + def __init__(self, start_layer=0, end_layer=12): + super(ViT_CIFAR10, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + img_size = 32 + patch_size = 4 + in_channels = 3 + embed_dim = 128 + num_classes = 10 + num_patches = (img_size // patch_size) ** 2 + + if (self.start_layer < 1) and (self.end_layer >= 1): + self.layer1 = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + if (self.start_layer < 2) and (self.end_layer >= 2): + self.layer2 = nn.Flatten(2) + + if (self.start_layer < 3) and (self.end_layer >= 3): + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + if (self.start_layer < 4) and (self.end_layer >= 4): + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim)) + self.layer4 = nn.Identity() + + if (self.start_layer < 5) and self.end_layer >= 5: + self.layer5 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 6) and (self.end_layer >= 6): + self.layer6 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 7) and (self.end_layer >= 7): + self.layer7 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 8) and (self.end_layer >= 8): + self.layer8 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 9) and (self.end_layer >= 9): + self.layer9 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 10) and (self.end_layer >= 10): + self.layer10 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 11) and (self.end_layer >= 11): + self.layer11 = nn.LayerNorm(embed_dim) + + if (self.start_layer < 12) and (self.end_layer >= 12): + self.layer12 = nn.Linear(embed_dim, num_classes) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + x = x.transpose(1, 2) + + if self.start_layer < 3 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) + + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x + self.pos_embed) + + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + + if self.start_layer < 10 <= 
self.end_layer: + x = self.layer10(x) + + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x[:, 0]) + + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + return x diff --git a/other/2LS/src/model/ViT_MNIST.py b/other/2LS/src/model/ViT_MNIST.py new file mode 100644 index 0000000..5d21645 --- /dev/null +++ b/other/2LS/src/model/ViT_MNIST.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +class TransformerEncoderBlock(nn.Module): + def __init__(self, embed_dim, num_heads=4, mlp_dim=256): + super().__init__() + self.ln1 = nn.LayerNorm(embed_dim) + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.ln2 = nn.LayerNorm(embed_dim) + self.mlp = nn.Sequential( + nn.Linear(embed_dim, mlp_dim), + nn.GELU(), + nn.Linear(mlp_dim, embed_dim) + ) + + def forward(self, x): + # Attention + residual + _x = self.ln1(x) + x_attn = self.mha(_x, _x, _x)[0] + x = x + x_attn + + # MLP + residual + x_mlp = self.mlp(self.ln2(x)) + x = x + x_mlp + return x + + +class ViT_MNIST(nn.Module): + + def __init__(self, start_layer=0, end_layer=12): + super(ViT_MNIST, self).__init__() + self.start_layer = start_layer + self.end_layer = end_layer + + img_size = 28 + patch_size = 4 + in_channels = 1 + embed_dim = 128 + num_classes = 10 + num_patches = (img_size // patch_size) ** 2 + + if (self.start_layer < 1) and (self.end_layer >= 1): + self.layer1 = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + if (self.start_layer < 2) and (self.end_layer >= 2): + self.layer2 = nn.Flatten(2) + + if (self.start_layer < 3) and (self.end_layer >= 3): + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + + if (self.start_layer < 4) and (self.end_layer >= 4): + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim)) + self.layer4 = nn.Identity() + + if (self.start_layer < 5) and self.end_layer >= 5: + self.layer5 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 6) and (self.end_layer >= 6): + self.layer6 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 7) and (self.end_layer >= 7): + self.layer7 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 8) and (self.end_layer >= 8): + self.layer8 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 9) and (self.end_layer >= 9): + self.layer9 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 10) and (self.end_layer >= 10): + self.layer10 = TransformerEncoderBlock(embed_dim=128) + + if (self.start_layer < 11) and (self.end_layer >= 11): + self.layer11 = nn.LayerNorm(embed_dim) + + if (self.start_layer < 12) and (self.end_layer >= 12): + self.layer12 = nn.Linear(embed_dim, num_classes) + + def forward(self, x): + if self.start_layer < 1 <= self.end_layer: + x = self.layer1(x) + + if self.start_layer < 2 <= self.end_layer: + x = self.layer2(x) + x = x.transpose(1, 2) + + if self.start_layer < 3 <= self.end_layer: + cls_token = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat([cls_token, x], dim=1) + + if self.start_layer < 4 <= self.end_layer: + x = self.layer4(x + self.pos_embed) + + if self.start_layer < 5 <= self.end_layer: + x = self.layer5(x) + + if self.start_layer < 6 <= self.end_layer: + x = self.layer6(x) + + if self.start_layer < 7 <= self.end_layer: + x = self.layer7(x) + + if self.start_layer < 8 <= self.end_layer: + x = self.layer8(x) + + if self.start_layer < 9 <= self.end_layer: + x = self.layer9(x) + + if self.start_layer < 10 <= self.end_layer: + x = self.layer10(x) 
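+        # The classification head consumes only the CLS token: layer11 applies
+        # LayerNorm to position 0 and layer12 projects it to the class logits.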
+ + if self.start_layer < 11 <= self.end_layer: + x = self.layer11(x[:, 0]) + + if self.start_layer < 12 <= self.end_layer: + x = self.layer12(x) + return x diff --git a/other/2LS/src/model/__init__.py b/other/2LS/src/model/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/other/2LS/src/model/__init__.py @@ -0,0 +1 @@ + diff --git a/other/2LS/src/train/BERT.py b/other/2LS/src/train/BERT.py new file mode 100644 index 0000000..0817a22 --- /dev/null +++ b/other/2LS/src/train/BERT.py @@ -0,0 +1,170 @@ +import time +import pickle +from tqdm import tqdm + +import torch +import torch.nn as nn + +import src.Log + +class Train_BERT: + def __init__(self, client_id, layer_id, channel, device, in_cluster_id, idx): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.data_count = 0 + self.size = None + self.in_cluster_id = in_cluster_id + self.idx = idx + + def send_intermediate_output(self, output, labels, trace): + + forward_queue_name = f'intermediate_queue_{self.layer_id}_{self.idx}' + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": trace} + ) + else: + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels.cpu(), "trace": [self.client_id]} + ) + if self.size is None: + self.size = len(message) + print(f'Length message: {self.size} (bytes).') + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + + message = pickle.dumps( + {"data": gradient.detach().cpu().numpy(), "trace": trace}) + + if self.size is None: + self.size = len(message) + print(f'Length message: {self.size} (bytes).') + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + def train_on_first_layer(self, model, learning, train_loader=None): + optimizer = torch.optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + model = model.to(self.device) + + for batch in tqdm(train_loader, desc="Fine tuning"): + model.train() + optimizer.zero_grad() + + input_ids = batch['input_ids'].to(self.device) + labels = batch['labels'].to(self.device) + + intermediate_output = model(input_ids=input_ids) + intermediate_output = intermediate_output.detach().requires_grad_(True) + + self.data_count += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None) + + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + gradient_numpy = received_data["data"] + gradient = torch.tensor(gradient_numpy).to(self.device) + + output = model(input_ids=input_ids) + output.backward(gradient=gradient) + 
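# The gradient fetched from the queue has now been backpropagated through a
+                    # locally recomputed forward pass (the earlier output was detached
+                    # before sending), so the step below updates this shard's parameters.
+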
optimizer.step() + break + + else: + time.sleep(0.5) + + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "in_cluster_id": self.in_cluster_id, + "message": "Finish training!"} + + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + while True: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return True, self.data_count + time.sleep(0.5) + + def train_on_last_layer(self, model, learning): + optimizer = torch.optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + criterion = nn.CrossEntropyLoss() + result = True + + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{self.idx}' + self.channel.queue_declare(queue=forward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + print('Waiting for intermediate output. To exit press CTRL+C') + model.to(self.device) + model.train() + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + optimizer.zero_grad() + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + trace = received_data["trace"] + labels = received_data["label"].to(self.device) + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).float().to(self.device) + + output = model(input_ids=intermediate_output) + + loss = criterion(output, labels) + + if torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", "yellow") + result = False + + print(f"Loss: {loss.item()}") + intermediate_output.retain_grad() + loss.backward() + + optimizer.step() + self.data_count += 1 + + gradient = intermediate_output.grad + self.send_gradient(gradient, trace) + + else: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return result, self.data_count + time.sleep(0.5) diff --git a/other/2LS/src/train/KWT.py b/other/2LS/src/train/KWT.py new file mode 100644 index 0000000..79695f5 --- /dev/null +++ b/other/2LS/src/train/KWT.py @@ -0,0 +1,168 @@ +import time +import pickle +from tqdm import tqdm + +import torch +import torch.optim as optim +import torch.nn as nn + +import src.Log + +class Train_KWT: + + def __init__(self, client_id, layer_id, channel, device, in_cluster_id, idx): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.data_count = 0 + self.in_cluster_id = in_cluster_id + self.idx = idx + self.size = None + + def send_intermediate_output(self, output, labels, trace): + forward_queue_name = f'intermediate_queue_{self.layer_id}_{self.idx}' + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels, "trace": trace} + ) + else: + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels, "trace": 
[self.client_id]} + ) + + if self.size is None: + self.size = len(message) + print(f'Length message: {self.size} (bytes).') + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + + message = pickle.dumps( + {"data": gradient.detach().cpu().numpy(), "trace": trace, "test": False}) + + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + def train_on_first_layer(self, model, learning, train_loader=None): + optimizer = optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + model.to(self.device) + + for batch in tqdm(train_loader, desc="Training"): + model.train() + optimizer.zero_grad() + training_data, labels = batch + training_data = training_data.to(self.device) + intermediate_output = model(training_data) + intermediate_output = intermediate_output.detach().requires_grad_(True) + + self.data_count += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None) + + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + gradient_numpy = received_data["data"] + gradient = torch.tensor(gradient_numpy).to(self.device) + + output = model(training_data) + output.backward(gradient=gradient) + optimizer.step() + break + + else: + time.sleep(0.5) + + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "in_cluster_id": self.in_cluster_id, + "message": "Finish training!"} + + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + broadcast_queue_name = f'reply_{self.client_id}' + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return True, self.data_count + time.sleep(0.5) + + def train_on_last_layer(self, model, learning): + optimizer = optim.AdamW(model.parameters(), lr=learning["learning-rate"], weight_decay=learning["weight-decay"]) + result = True + + criterion = nn.CrossEntropyLoss() + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{self.idx}' + + self.channel.queue_declare(queue=forward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + print('Waiting for intermediate output. 
To exit press CTRL+C') + model.to(self.device) + + while True: + model.train() + optimizer.zero_grad() + + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + trace = received_data["trace"] + labels = received_data["label"].to(self.device) + + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + + output = model(intermediate_output) + + loss = criterion(output, labels) + print(f"Loss: {loss.item()}") + if torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", "yellow") + result = False + + intermediate_output.retain_grad() + loss.backward() + optimizer.step() + self.data_count += 1 + gradient = intermediate_output.grad + self.send_gradient(gradient, trace) + else: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return result, self.data_count + + time.sleep(0.5) diff --git a/other/2LS/src/train/VGG16.py b/other/2LS/src/train/VGG16.py new file mode 100644 index 0000000..a231ed0 --- /dev/null +++ b/other/2LS/src/train/VGG16.py @@ -0,0 +1,166 @@ +import time +import pickle +from tqdm import tqdm + +import torch +import torch.optim as optim +import torch.nn as nn + +import src.Log + +class Train_VGG16: + def __init__(self, client_id, layer_id, channel, device, in_cluster_id, idx): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.in_cluster_id = in_cluster_id + self.idx = idx + self.data_count = 0 + self.size = None + + def send_intermediate_output(self, output, labels, trace): + forward_queue_name = f'intermediate_queue_{self.layer_id}_{self.idx}' + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels, "trace": trace} + ) + else: + message = pickle.dumps( + {"data": output.detach().cpu().numpy(), "label": labels, "trace": [self.client_id]} + ) + + if self.size is None: + self.size = len(message) + print(f'Length message: {self.size} (bytes).') + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{self.layer_id - 1}_{to_client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + + message = pickle.dumps( + {"data": gradient.detach().cpu().numpy(), "trace": trace, "test": False}) + + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish(exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message)) + + def train_on_first_layer(self, model, learning, train_loader=None): + optimizer = optim.SGD(model.parameters(), lr=learning["learning-rate"], momentum=learning["momentum"]) + + backward_queue_name = f'gradient_queue_{self.layer_id}_{self.client_id}' + self.channel.queue_declare(queue=backward_queue_name, durable=False) + 
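# basic_qos(prefetch_count=1) asks RabbitMQ to deliver at most one
+        # unacknowledged message at a time on this channel.
+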
self.channel.basic_qos(prefetch_count=1) + model.to(self.device) + + for batch in tqdm(train_loader, desc="Training"): + model.train() + optimizer.zero_grad() + training_data, labels = batch + training_data = training_data.to(self.device) + intermediate_output = model(training_data) + intermediate_output = intermediate_output.detach().requires_grad_(True) + + self.data_count += 1 + + self.send_intermediate_output(intermediate_output, labels, trace=None) + + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + gradient_numpy = received_data["data"] + gradient = torch.tensor(gradient_numpy).to(self.device) + + output = model(training_data) + output.backward(gradient=gradient) + optimizer.step() + break + + else: + time.sleep(0.5) + + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "in_cluster_id": self.in_cluster_id, + "message": "Finish training!"} + + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + broadcast_queue_name = f'reply_{self.client_id}' + while True: + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return True , self.data_count + time.sleep(0.5) + + def train_on_last_layer(self, model, learning): + optimizer = optim.SGD(model.parameters(), lr=learning["learning-rate"], momentum=learning["momentum"]) + result = True + + criterion = nn.CrossEntropyLoss() + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{self.idx}' + + self.channel.queue_declare(queue=forward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + print('Waiting for intermediate output. 
To exit press CTRL+C') + model.to(self.device) + + while True: + model.train() + optimizer.zero_grad() + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + trace = received_data["trace"] + labels = received_data["label"].to(self.device) + + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + + output = model(intermediate_output) + + loss = criterion(output, labels) + print(f"Loss: {loss.item()}") + if torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", "yellow") + result = False + + intermediate_output.retain_grad() + loss.backward() + optimizer.step() + self.data_count += 1 + gradient = intermediate_output.grad + self.send_gradient(gradient, trace) + + else: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return result, self.data_count + time.sleep(0.5) \ No newline at end of file diff --git a/other/2LS/src/val/BERT.py b/other/2LS/src/val/BERT.py new file mode 100644 index 0000000..d26b717 --- /dev/null +++ b/other/2LS/src/val/BERT.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from tqdm import tqdm + +from src.model.BERT_AGNEWS import BERT_AGNEWS +from src.dataset.dataloader import data_loader + +def val_BERT(data_name, state_dict_full, logger): + criterion = nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + test_loader = data_loader(data_name=data_name,train=False) + model = BERT_AGNEWS() + model = model.to(device) + model.load_state_dict(state_dict_full) + + model.eval() + correct, total, total_loss = 0, 0, 0 + + with torch.no_grad(): + for batch in tqdm(test_loader): + input_ids = batch['input_ids'].to(device) + labels = batch['labels'].to(device) + + logits = model(input_ids) + loss = criterion(logits, labels) + total_loss += loss.item() + correct += (logits.argmax(1) == labels).sum().item() + total += labels.size(0) + + acc = correct / total + avg_loss = total_loss / len(test_loader) + + print(f"Test Loss: {avg_loss:.2f}; Test Acc: {acc:.2f}") + + logger.log_info(f"Test Loss: {avg_loss:.2f}; Test Acc: {acc:.2f}") + + + + + + + + diff --git a/other/2LS/src/val/KWT.py b/other/2LS/src/val/KWT.py new file mode 100644 index 0000000..be0e004 --- /dev/null +++ b/other/2LS/src/val/KWT.py @@ -0,0 +1,39 @@ +from tqdm import tqdm +import torch +import torch.nn as nn + +from src.dataset.dataloader import data_loader +from src.model.KWT_SPEECHCOMMANDS import KWT_SPEECHCOMMANDS + +def val_KWT(data_name, state_dict_full, logger): + criterion = nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device) + test_loader = data_loader(data_name=data_name, train=False) + + model = KWT_SPEECHCOMMANDS() + model.load_state_dict(state_dict_full) + model.to(device) + model.eval() + + correct, total, total_loss = 0, 0, 0 + + with torch.no_grad(): + for mfcc, labels in tqdm(test_loader): + mfcc = mfcc.to(device) + labels = labels.to(device) + + outputs = model(mfcc) + loss = criterion(outputs, labels) + + total_loss += loss.item() + correct += (outputs.argmax(1) == labels).sum().item() + 
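# total accumulates the number of evaluated samples; accuracy is
+                # computed below as (correct / total) * 100.
+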
total += labels.size(0) + + acc = (correct / total) * 100 + avg_loss = total_loss / len(test_loader) + + print('Test set: Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) + logger.log_info('Test set: Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) \ No newline at end of file diff --git a/other/2LS/src/val/VGG16.py b/other/2LS/src/val/VGG16.py new file mode 100644 index 0000000..0d0b1dc --- /dev/null +++ b/other/2LS/src/val/VGG16.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + +from src.model.VGG16_CIFAR10 import VGG16_CIFAR10 +from tqdm import tqdm +from src.dataset.dataloader import data_loader + +def val_VGG16(data_name, state_dict_full, logger): + criterion = nn.CrossEntropyLoss() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device) + test_loader = data_loader(data_name=data_name, train=False) + + model = VGG16_CIFAR10() + model = model.to(device) + model.load_state_dict(state_dict_full) + model.eval() + + correct, total, total_loss = 0, 0, 0 + + with torch.no_grad(): + for images, labels in tqdm(test_loader): + images = images.to(device) + labels = labels.to(device) + + outputs = model(images) + loss = criterion(outputs, labels) + + total_loss += loss.item() + correct += (outputs.argmax(1) == labels).sum().item() + total += labels.size(0) + + acc = (correct / total) * 100 + avg_loss = total_loss / len(test_loader) + + print('Test set: Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) + logger.log_info('Test set: Loss: {:.4f}; Accuracy: {}/{} ({:.2f}%)\n'.format(avg_loss, + correct, total, acc)) \ No newline at end of file diff --git a/other/2LS/src/val/__init__.py b/other/2LS/src/val/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/other/2LS/src/val/get_val.py b/other/2LS/src/val/get_val.py new file mode 100644 index 0000000..dee914a --- /dev/null +++ b/other/2LS/src/val/get_val.py @@ -0,0 +1,17 @@ +from src.val.VGG16 import val_VGG16 +from src.val.BERT import val_BERT +from src.val.KWT import val_KWT + +def get_val(model_name, data_name, state_dict_full, logger): + if model_name == 'BERT': + val_BERT(data_name, state_dict_full, logger) + return True + elif model_name == 'VGG16': + val_VGG16(data_name, state_dict_full, logger) + return True + elif model_name == 'KWT': + val_KWT(data_name, state_dict_full, logger) + return True + else: + return False +
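
For reference, a minimal sketch of how `get_val` might be invoked once the server has merged every layer shard into a single state dict. The checkpoint path and the model/data names below are illustrative assumptions, not values taken from this repo:

```python
# Hypothetical driver for src/val/get_val.py (a sketch, not part of the repo).
import torch

from src.Log import Logger
from src.val.get_val import get_val

# Assumption: the server has saved the merged global weights to this path.
state_dict_full = torch.load("global_model.pth", map_location="cpu")

logger = Logger(log_path="./validation.log", debug_mode=True)

# The model name must match a branch in get_val ('BERT', 'VGG16' or 'KWT');
# the data name is forwarded to the dataloader (e.g. 'CIFAR10' for VGG16).
if not get_val("VGG16", "CIFAR10", state_dict_full, logger):
    logger.log_warning("No validation routine registered for this model name.")
```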