Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions other/2LS/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pika
import uuid
import argparse
import yaml

import torch

import src.Log
from src.RpcClient import RpcClient
from src.Scheduler import Scheduler

parser = argparse.ArgumentParser(description="Split learning framework")
parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1')
parser.add_argument('--device', type=str, required=False, help='Device of client')
# add new argument
parser.add_argument('--idx', type=int, required=True, help='index of client')
parser.add_argument('--incluster', type=int, required=False, default=0, help='In-cluster ID')
parser.add_argument('--outcluster', type=int, required=False, default=0, help='Out-cluster ID')
args = parser.parse_args()

with open('config.yaml', 'r') as file:
config = yaml.safe_load(file)

client_id = uuid.uuid4()
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
password = config["rabbit"]["password"]
virtual_host = config["rabbit"]["virtual-host"]

device = None
if args.device is None:
if torch.cuda.is_available():
device = "cuda"
print(f"Using device: {torch.cuda.get_device_name(device)}")
else:
device = "cpu"
print(f"Using device: CPU")
else:
device = args.device
print(f"Using device: {device}")

credentials = pika.PlainCredentials(username, password)
connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials))
channel = connection.channel()

in_cluster_id = args.incluster
out_cluster_id = args.outcluster
idx = args.idx

if __name__ == "__main__":
src.Log.print_with_color("[>>>] Client sending registration message to server...", "red")

data = {"action": "REGISTER", "client_id": client_id, "idx": idx, "layer_id": args.layer_id,
"in_cluster_id": in_cluster_id, "out_cluster_id": out_cluster_id, "message": "Hello from Client!"}

scheduler = Scheduler(client_id, args.layer_id, channel, device, in_cluster_id=in_cluster_id, idx=idx)

client = RpcClient(client_id, args.layer_id, channel, scheduler.train_on_device, device)
client.send_to_server(data)
client.wait_response()

41 changes: 41 additions & 0 deletions other/2LS/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Split Learning
server:
local-round: 1
global-round: 1

clients:
- 1
- 1
no-cluster:
cut-layers: [1]
manual-cluster:
num-cluster: 1
cut-layers: [1]

model: VGG16
data-name: CIFAR10
parameters:
load: False
save: False
validation: False
data-distribution:
non-iid: False
num-sample: 5000
num-label: 10
dirichlet:
alpha: 1
random-seed: 1

rabbit:
address: 127.0.0.1
username: admin
password: admin
virtual-host: /

log_path: .
debug_mode: True

learning:
learning-rate: 0.01
momentum: 0.5
batch-size: 32
25 changes: 25 additions & 0 deletions other/2LS/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SERVER

```
python3 server.py

# SPLIT SERVER (Layer 2)
python3 client.py --layer_id 2 --idx 0 --incluster 0 --outcluster 0
python3 client.py --layer_id 2 --idx 1 --incluster 0 --outcluster 0
python3 client.py --layer_id 2 --idx 2 --incluster 1 --outcluster 0

# OUT-CLUSTER 0 - Layer 1
python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 0
python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 0
python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 0


# OUT-CLUSTER 1 - Layer 1
python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 1
python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 1
python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 1

# OUT-CLUSTER 2 - Layer 1
python3 client.py --layer_id 1 --idx 0 --incluster 0 --outcluster 2
python3 client.py --layer_id 1 --idx 1 --incluster 0 --outcluster 2
python3 client.py --layer_id 1 --idx 2 --incluster 1 --outcluster 2
32 changes: 32 additions & 0 deletions other/2LS/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import argparse
import sys
import signal
from src.Server import Server
from src.Utils import delete_old_queues
import src.Log
import yaml

parser = argparse.ArgumentParser(description="Split learning framework with controller.")

args = parser.parse_args()

with open('config.yaml') as file:
config = yaml.safe_load(file)
address = config["rabbit"]["address"]
username = config["rabbit"]["username"]
password = config["rabbit"]["password"]
virtual_host = config["rabbit"]["virtual-host"]


def signal_handler(sig, frame):
print("\nCatch stop signal Ctrl+C. Stop the program.")
delete_old_queues(address, username, password, virtual_host)
sys.exit(0)


if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
delete_old_queues(address, username, password, virtual_host)
server = Server(config)
server.start()
src.Log.print_with_color("Ok, ready!", "green")
66 changes: 66 additions & 0 deletions other/2LS/src/Log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import logging
import os


class Colors:
COLORS = {
"header": '\033[95m',
"blue": '\033[94m',
"green": '\033[92m',
"yellow": '\033[93m',
"red": '\033[91m',
"end": '\033[0m'
}


class Logger:
def __init__(self, log_path, debug_mode=False, minimal=False):
# Thiết lập logger với tên "my_logger"
self.logger = logging.getLogger("my_logger")
self.logger.setLevel(logging.DEBUG) # Mức log
self.debug_mode = debug_mode
self.minimal = minimal

# Clear existing handlers to avoid duplicate logs if re-initialized
if self.logger.hasHandlers():
self.logger.handlers.clear()

if log_path:
# Tạo thư mục log nếu chưa tồn tại
log_dir = os.path.dirname(log_path)
if log_dir:
os.makedirs(log_dir, exist_ok=True)

# Tạo file handler để ghi log vào file
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.DEBUG)

# Định dạng log
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

# Gắn file handler vào logger
self.logger.addHandler(file_handler)

def log_info(self, message):
if not self.minimal:
print(f"[INFO] {message}")
self.logger.info(message)

def log_warning(self, message):
print_with_color(f"[WARN] {message}", "yellow")
self.logger.warning(message)

def log_error(self, message):
print_with_color(f"[ERROR] {message}", "red")
self.logger.error(message)

def log_debug(self, message):
if self.debug_mode:
print_with_color(f"[DEBUG] {message}", "green")
self.logger.debug(message)


def print_with_color(text, color):
color_code = Colors.COLORS.get(color.lower(), Colors.COLORS["end"])
print(f"{color_code}{text}{Colors.COLORS['end']}")
Loading