54 changes: 27 additions & 27 deletions pcgrad.py
@@ -20,7 +20,10 @@
from operator import mul
from functools import reduce

def pc_grad_update(gradient_list):
import os
from multiprocessing import Pool

def pc_grad_update(gradient_list, num_workers=8):
'''
PyTorch implementation of PCGrad.
Gradient Surgery for Multi-Task Learning: https://arxiv.org/pdf/2001.06782.pdf
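For context, the projection step the paper prescribes: when two task gradients conflict (negative inner product), the component of one along the other is subtracted, i.e. g_i <- g_i - (g_i . g_j / ||g_j||^2) * g_j. A minimal single-pair sketch on already-flattened gradients (the function and tensor names are illustrative, not part of this file):

import torch

def project_away_conflict(g_i, g_j):
    # If g_i and g_j conflict (negative inner product), remove from g_i
    # its component along g_j: g_i <- g_i - (g_i . g_j / ||g_j||^2) * g_j
    inner = torch.dot(g_i, g_j)
    if inner < 0.:
        g_i = g_i - (inner / torch.norm(g_j) ** 2) * g_j
    return g_i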
@@ -36,24 +39,15 @@ def pc_grad_update(gradient_list):
assert type(gradient_list) is list
assert len(gradient_list) != 0
num_tasks = len(gradient_list)
num_params = len(gradient_list[0])
np.random.shuffle(gradient_list)
# gradient_list = torch.stack(gradient_list)

# grad_dims = []
# reduce step: acc is (flattened gradient so far, list of original shapes)
def flatten_and_store_dims(acc, param_grad):
output, grad_dim = acc

def flatten_and_store_dims(grad_task):
output = []
grad_dim = []
for param_grad in grad_task: # TODO(speedup): convert to map since they are faster
if grad_dim is not None:
grad_dim.append(tuple(param_grad.shape))
output.append(torch.flatten(param_grad))

# grad_dims.append(grad_dim)

return torch.cat(output), grad_dim

# gradient_list = list(map(flatten_and_store_dims, gradient_list))
output = torch.cat([output, torch.flatten(param_grad)])
return output, grad_dim
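For reference, a self-contained sketch of how a reducer like this composes with functools.reduce, assuming the accumulator starts from an empty 1-D tensor and an empty shape list (flatten_step and the sample shapes are illustrative):

import torch
from functools import reduce

def flatten_step(acc, param_grad):
    # acc carries (flat gradient so far, list of original shapes)
    flat, shapes = acc
    shapes.append(tuple(param_grad.shape))
    return torch.cat([flat, torch.flatten(param_grad)]), shapes

grads = [torch.randn(3, 4), torch.randn(5)]  # per-parameter gradients for one task
flat, shapes = reduce(flatten_step, grads, (torch.empty(0), []))
assert flat.shape == (17,) and shapes == [(3, 4), (5,)]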

def restore_dims(grad_task, chunk_dims):
## chunk_dims is a list of tensor shapes
@@ -78,28 +72,34 @@ def project_gradients(grad_task):
Returns:
Component subtracted gradient
"""
grad_task, grad_dim = flatten_and_store_dims(grad_task)

for k in range(num_tasks): # TODO(speedup): convert to map since they are faster
conflict_gradient_candidate = gradient_list[k]
def get_projected_gradient_sum(k, conflict_gradient_candidate):
# no need to store the candidate's dims since it is never reshaped back or written into gradient_list
conflict_gradient_candidate, _ = flatten_and_store_dims(grad_task)
conflict_gradient_candidate, _ = reduce(flatten_and_store_dims, conflict_gradient_candidate, (torch.empty(0), None))

inner_product = torch.dot(torch.flatten(grad_task), torch.flatten(conflict_gradient_candidate))
# TODO(speedup): put conflict check condition here so that we aren't calculating norms for non-conflicting gradients
if inner_product >= 0.:
## no conflict, skip the heavy norm and projection work
return 0

proj_direction = inner_product / torch.norm(conflict_gradient_candidate)**2

## sanity check to see if there's any conflicting gradients
# if proj_direction < 0.:
# print('conflict')
# TODO(speedup): This is a cumulative subtraction, move to threaded in-memory map-reduce
grad_task = grad_task - min(proj_direction, 0.) * conflict_gradient_candidate

# return the conflicting component so the caller can subtract the summed projections
return proj_direction * conflict_gradient_candidate

grad_task, grad_dim = reduce(flatten_and_store_dims, grad_task, (torch.empty(0), []))
pool = Pool(num_workers)
# Note: a Pool within a Pool might not be a good idea; daemonic pool workers cannot spawn child processes, in addition to the zombie reaping problem
results = pool.starmap(get_projected_gradient_sum, zip(list(range(num_tasks)), gradient_list))

grad_task = grad_task - reduce(lambda a, b: a + b, results)
# restore grad_task to its original per-parameter shapes
grad_task = restore_dims(grad_task, grad_dim)

return grad_task

flattened_grad_task = list(map(project_gradients, gradient_list))
flatmappool = Pool(num_workers)
flattened_grad_task = list(flatmappool.map(project_gradients, gradient_list))

yield flattened_grad_task

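A sketch of how the expected input could be assembled and how the generator would be consumed; the model, loss, and tensor names are illustrative, not from this repo:

import torch
import torch.nn as nn

# Build the expected input: one entry per task, each entry a list of
# per-parameter gradient tensors for that task.
model = nn.Linear(10, 1)
params = list(model.parameters())
x = torch.randn(4, 10)
task_losses = [model(x).mean(), (model(x) ** 2).mean()]
gradient_list = [list(torch.autograd.grad(loss, params)) for loss in task_losses]

# pc_grad_update is a generator, so a caller would pull the projected,
# re-shaped per-task gradients with next(pc_grad_update(gradient_list, num_workers=2)),
# e.g. to sum them across tasks and write the result into each parameter's .grad.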