-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
132 lines (113 loc) · 4.95 KB
/
server.py
File metadata and controls
132 lines (113 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# server.py - My Flask "manager" for the U-shaped split learning.
# Holds the middle model layer, handles forward/backward requests from clients.
# Saves its weights locally after each update.
import os
import uuid

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flask import Flask, request, send_file, jsonify
from flax import serialization
from flax.training import train_state
# --- Server's Part of the Model ---
class ServerModel(nn.Module):
    """Server-held middle segment of the split model.

    Applies one 5x5 convolution with 32 features, a ReLU, and a 2x2
    max-pool with stride 2. Per the shape notes kept with this model, the
    input is (batch, 14, 14, 16) and the output is (batch, 7, 7, 32).
    """

    @nn.compact
    def __call__(self, x):
        # 'SAME' padding keeps the spatial size; only the pool halves it.
        conv = nn.Conv(features=32, kernel_size=(5, 5), padding='SAME')
        activated = nn.relu(conv(x))
        pooled = nn.max_pool(activated, window_shape=(2, 2), strides=(2, 2))
        return pooled
# --- Server State ---
# Flask application object; the route decorators below register on it.
app = Flask(__name__)
# File the serialized server state is written to and read back from.
# It is a relative path, so open()/os.path.exists resolve it against the
# process's current working directory.
SERVER_WEIGHTS_PATH = "server_weights.msgpack"
# Address printed in the startup banner. It is not passed to app.run(), so
# keep it in sync with the host/port given there.
SERVER_URL = "http://127.0.0.1:5000" # Define the URL here
# Current TrainState of the server model; stays None until setup_server()
# assigns it.
server_state = None
# Pullback (vjp) functions produced by /forward, keyed by Step-ID, kept until
# the matching /backward request pops them.
vjp_cache = {}
# Batch dimension of the template arrays the routes hand to
# serialization.from_bytes; hardcoded for now.
BATCH_SIZE = 64
# --- Load or Init Weights ---
def setup_server():
    """Populate the global ``server_state``, from disk when weights exist.

    A freshly initialised TrainState is always built first, because it is
    the structural template that deserialization restores into. If
    ``SERVER_WEIGHTS_PATH`` exists its contents are loaded into that
    template; otherwise the fresh state is used and written to that path.
    """
    global server_state

    model = ServerModel()
    # The dummy input only fixes parameter shapes; per the original notes it
    # must match the shape of Client 1's output.
    init_params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 14, 14, 16]))
    template_state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_params,
        tx=optax.adam(0.001),
    )

    if not os.path.exists(SERVER_WEIGHTS_PATH):
        print("Initializing new server weights...")
        server_state = template_state
        # Serialize before opening so a failure does not leave an empty file.
        fresh_bytes = serialization.to_bytes(server_state)
        with open(SERVER_WEIGHTS_PATH, 'wb') as weights_file:
            weights_file.write(fresh_bytes)
    else:
        print("Loading existing server weights...")
        with open(SERVER_WEIGHTS_PATH, 'rb') as weights_file:
            saved_bytes = weights_file.read()
        server_state = serialization.from_bytes(template_state, saved_bytes)

    print("Server ready.")
# --- API Endpoint 1: Forward Pass ---
@app.route('/forward', methods=['POST'])
def forward_step():
    """Run the server's forward pass on activations posted by a client.

    The request body carries the serialized activations ("smashed data").
    The pullback function from ``jax.vjp`` is cached under a fresh Step-ID
    so that ``/backward`` can later finish the same step.

    Returns:
        ``(body, 200, headers)`` where ``body`` is the serialized server
        activations and ``headers['Step-ID']`` is the token the caller must
        send back to ``/backward``;
        ``("Deserialization error", 400)`` if the body cannot be decoded;
        ``("Server not initialized", 503)`` if ``server_state`` is unset.
    """
    if server_state is None:
        # In this file setup_server() is only called from the __main__ guard,
        # so the state is missing when the app is started any other way.
        # Answer clearly instead of failing with an AttributeError.
        return "Server not initialized", 503

    smashed_data_bytes = request.data
    # Template array handed to from_bytes as the deserialization target.
    dummy_smashed_data = jnp.ones([BATCH_SIZE, 14, 14, 16], dtype=jnp.float32)
    try:
        smashed_data = serialization.from_bytes(dummy_smashed_data, smashed_data_bytes)
    except Exception as e:
        print(f"Error deserializing forward data: {e}")
        return "Deserialization error", 400

    # Forward pass; server_vjp is the pullback used later by /backward.
    server_activations, server_vjp = jax.vjp(server_state.apply_fn, server_state.params, smashed_data)

    # Use a random token rather than a hash of the payload: two in-flight
    # requests with identical bytes would otherwise share one cache slot,
    # and the second would overwrite the first's pullback.
    step_id = uuid.uuid4().hex
    vjp_cache[step_id] = server_vjp

    # Send activations back to the client along with the step token.
    response_bytes = serialization.to_bytes(server_activations)
    return response_bytes, 200, {'Content-Type': 'application/octet-stream', 'Step-ID': step_id}
# --- API Endpoint 2: Backward Pass ---
@app.route('/backward', methods=['POST'])
def backward_step():
    """Finish a training step: backpropagate, update weights, persist them.

    The ``Step-ID`` request header selects the pullback cached by
    ``/forward``; the request body carries the serialized gradients with
    respect to the server's activations.

    Returns:
        ``(body, 200, headers)`` where ``body`` is the serialized gradients
        with respect to the smashed data (to be passed on to Client 1);
        ``("Deserialization error", 400)`` if the body cannot be decoded;
        ``("Error: Step ID not found", 400)`` if no pullback is cached under
        the given Step-ID.
    """
    global server_state

    step_id = request.headers.get('Step-ID')
    server_activations_grads_bytes = request.data
    # Template array for the gradients coming from Client 2.
    dummy_grads = jnp.ones([BATCH_SIZE, 7, 7, 32], dtype=jnp.float32)
    try:
        server_activations_grads = serialization.from_bytes(dummy_grads, server_activations_grads_bytes)
    except Exception as e:
        print(f"Error deserializing backward data: {e}")
        return "Deserialization error", 400

    # The pullback is popped only after the body decoded, so a malformed
    # request does not consume the cached step.
    server_vjp = vjp_cache.pop(step_id, None)
    if server_vjp is None:
        print(f"Error: Step ID {step_id} not found.")
        return "Error: Step ID not found", 400

    # Gradients for the server weights and for the smashed data.
    server_grads, smashed_data_grads = server_vjp(server_activations_grads)
    server_state = server_state.apply_gradients(grads=server_grads)

    # Persist atomically: write a sibling temp file, then swap it into place.
    # Rewriting the file in place would let /get_weights, or a crash
    # mid-write, observe a truncated file. step_id was just popped from the
    # cache, so it is a server-issued token and unique among concurrent
    # writers.
    model_bytes = serialization.to_bytes(server_state)
    tmp_path = f"{SERVER_WEIGHTS_PATH}.{step_id}.tmp"
    with open(tmp_path, 'wb') as f:
        f.write(model_bytes)
    os.replace(tmp_path, SERVER_WEIGHTS_PATH)

    # Send gradients for smashed_data back to Client 1.
    response_bytes = serialization.to_bytes(smashed_data_grads)
    return response_bytes, 200, {'Content-Type': 'application/octet-stream'}
# --- API Endpoint 3: Share Weights for Testing ---
@app.route('/get_weights', methods=['GET'])
def get_weights():
    """Serve the saved server weights file.

    Returns:
        The file at ``SERVER_WEIGHTS_PATH`` as the response, or
        ``("No weights saved yet.", 404)`` if it does not exist.
    """
    # Resolve against the current working directory, which is the base
    # open() and os.path.exists use when the weights are written. Flask's
    # send_file resolves a relative path against the application root
    # instead, so a relative path could pass the check here and then fail to
    # be found (or serve a different file) when the two directories differ.
    weights_path = os.path.abspath(SERVER_WEIGHTS_PATH)
    if not os.path.exists(weights_path):
        return "No weights saved yet.", 404
    return send_file(weights_path)
# --- Start Server ---
if __name__ == '__main__':
    # Restore or create the weights before the server accepts requests.
    setup_server()
    banner = f"--- JAX U-Shaped Server RUNNING on {SERVER_URL} ---"
    print(banner)
    # Local-only bind address for the Flask development server.
    listen_on = {'host': '127.0.0.1', 'port': 5000}
    app.run(**listen_on)