-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
96 lines (81 loc) · 3.23 KB
/
setup.py
File metadata and controls
96 lines (81 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# setup.py - Prepares MNIST data for JAX Split Learning & uploads to IPFS.
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import requests
import os
import ssl
# SECURITY NOTE(review): this replaces the process-wide factory that the
# standard library uses to build its default HTTPS context with one that
# skips certificate verification, so stdlib HTTPS connections made after
# this line are not authenticated. Presumably a workaround for certificate
# errors during the MNIST download below -- confirm, and prefer fixing the
# local CA bundle over disabling verification.
ssl._create_default_https_context = ssl._create_unverified_context
# --- Config ---
# Number of training shards to produce (one per split-learning client).
NUM_CLIENTS = 3
# Directory that receives the generated .npz archives.
DATA_DIR = './data_splits_jax'

# --- 1. Make sure data dir exists ---
# Attempt the creation and react to the outcome instead of checking first:
# this avoids the window between an existence check and the mkdir call.
try:
    os.makedirs(DATA_DIR)
except FileExistsError:
    # Already present: nothing to create, and nothing to announce.
    pass
else:
    print(f"Created directory: {DATA_DIR}")
# --- 2. Download MNIST ---
# Fetch the training split first, then the test split, both cached under ./data.
print("Downloading MNIST dataset...")
trainset_full, testset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)
# --- 3. Convert to NumPy (No Normalization) ---
def convert_to_numpy(dataset):
    """Return a dataset's images and labels as NumPy arrays.

    Args:
        dataset: Object exposing ``data`` and ``targets`` attributes that
            each provide a ``.numpy()`` method (as the MNIST datasets
            loaded in this script do).

    Returns:
        A ``(data, labels)`` tuple. ``data`` is float32 with one extra
        trailing axis of length 1 appended to the image array; ``labels``
        is whatever ``dataset.targets.numpy()`` yields, unchanged.
    """
    images = dataset.data.numpy()
    labels = dataset.targets.numpy()
    # Need channel last (N, H, W, C) for JAX conv layers
    images = np.expand_dims(images, axis=-1).astype(np.float32)
    # NOTE: pixel values are deliberately left unnormalized (no scaling).
    return images, labels
# Convert the training split, then the test split, to NumPy arrays and
# echo the training array's shape.
(train_data, train_labels), (test_data, test_labels) = (
    convert_to_numpy(split) for split in (trainset_full, testset)
)
print(f"Train data shape (unnormalized): {train_data.shape}")
# --- 4. Split the training set into NUM_CLIENTS shards; save shards and test set ---
print(f"Splitting training data into {NUM_CLIENTS} parts...")

# np.array_split keeps every sample: when the sample count is not a multiple
# of NUM_CLIENTS the leading shards receive one extra item each, instead of
# the remainder being silently dropped by floor division. For an evenly
# divisible count the shards are identical contiguous, in-order slices.
data_shards = np.array_split(train_data, NUM_CLIENTS)
label_shards = np.array_split(train_labels, NUM_CLIENTS)

for client_no, (client_data, client_labels) in enumerate(
    zip(data_shards, label_shards), start=1
):
    # Save as compressed NumPy archive
    file_path = os.path.join(DATA_DIR, f'client_{client_no}_data.npz')
    np.savez_compressed(file_path, data=client_data, labels=client_labels)
    print(f" - Saved {file_path}")

# Save the test set (kept whole, not sharded)
test_file_path = os.path.join(DATA_DIR, 'test_data.npz')
np.savez_compressed(test_file_path, data=test_data, labels=test_labels)
print(f" - Saved {test_file_path}")
# --- 5. Upload to IPFS ---
def add_to_ipfs(file_path, *, api_url='http://127.0.0.1:5001/api/v0/add', timeout=120):
    """POST a file to an IPFS daemon's HTTP API and return its CID.

    Args:
        file_path: Path of the file to upload; it is opened in binary mode.
        api_url: URL of the daemon's ``add`` endpoint. Defaults to the
            local daemon address this script has always used.
        timeout: Value forwarded to ``requests.post`` as ``timeout`` so a
            stalled daemon cannot block the script indefinitely.

    Returns:
        The ``Hash`` field of the daemon's JSON reply (the CID), or ``None``
        if the daemon cannot be reached, the request times out, or the
        reply status is not 200.

    Raises:
        OSError: If ``file_path`` cannot be opened (not handled here).
    """
    try:
        # Keep the file open only for the duration of the request.
        with open(file_path, 'rb') as f:
            response = requests.post(api_url, files={'file': f}, timeout=timeout)
    except requests.exceptions.ConnectionError:
        print("\n--- IPFS CONNECTION ERROR ---")
        print("Could not connect. Is 'ipfs daemon --offline' running?")
        return None
    except requests.exceptions.Timeout:
        # Read timeouts are not ConnectionError subclasses, so they need
        # their own handler to keep the "return None on failure" contract.
        print("\n--- IPFS TIMEOUT ---")
        print(f"No response from the IPFS daemon within {timeout} seconds.")
        return None

    if response.status_code != 200:
        print(f"Error adding file: {response.text}")
        return None
    return response.json()['Hash']  # Get the CID
print("\n--- Adding Client Data to IPFS ---")
# Maps an upload name ('client_<n>' or 'test_data') to the CID returned for it.
cids = {}
for i in range(NUM_CLIENTS):
    file_path = os.path.join(DATA_DIR, f'client_{i+1}_data.npz')
    cid = add_to_ipfs(file_path)
    if cid:
        cids[f'client_{i+1}'] = cid
        print(f"Client {i+1} data added. CID: {cid}")

test_cid = add_to_ipfs(test_file_path)
if test_cid:
    cids['test_data'] = test_cid
    print(f"Test data added. CID: {test_cid}")

# Setup is only complete when every shard AND the test set received a CID;
# a partial upload would leave some clients without data, so report it as
# a failure and name what is missing.
expected_uploads = [f'client_{n}' for n in range(1, NUM_CLIENTS + 1)] + ['test_data']
missing_uploads = [name for name in expected_uploads if name not in cids]
if not missing_uploads:
    print("\n--- SETUP COMPLETE ---")
else:
    print("\n--- SETUP FAILED ---")
    print(f"Missing uploads: {', '.join(missing_uploads)}")