diff --git a/pcgrad.py b/pcgrad.py
index 9c6288b..845c9a1 100644
--- a/pcgrad.py
+++ b/pcgrad.py
@@ -20,7 +20,11 @@
 from operator import mul
 from functools import reduce
 
-def pc_grad_update(gradient_list):
+# ThreadPool rather than a process Pool: the workers below call nested closures,
+# which cannot be pickled across process boundaries
+from multiprocessing.pool import ThreadPool as Pool
+
+def pc_grad_update(gradient_list, num_workers=8):
     '''
     PyTorch implementation of PCGrad.
     Gradient Surgery for Multi-Task Learning: https://arxiv.org/pdf/2001.06782.pdf
@@ -36,23 +40,18 @@ def pc_grad_update(gradient_list):
     assert type(gradient_list) is list
     assert len(gradient_list) != 0
     num_tasks = len(gradient_list)
-    num_params = len(gradient_list[0])
     np.random.shuffle(gradient_list)
-    # gradient_list = torch.stack(gradient_list)
-    # grad_dims = []
 
-    def flatten_and_store_dims(grad_task):
-        output = []
-        grad_dim = []
-        for param_grad in grad_task: # TODO(speedup): convert to map since they are faster
+    def flatten_and_store_dims(acc, param_grad):
+        ## reduce step: the accumulator (flat output so far, list of shapes)
+        ## comes first because functools.reduce passes it as the first argument
+        output, grad_dim = acc
+        if grad_dim is not None:
             grad_dim.append(tuple(param_grad.shape))
-            output.append(torch.flatten(param_grad))
-
-        # grad_dims.append(grad_dim)
-
-        return torch.cat(output), grad_dim
-
-    # gradient_list = list(map(flatten_and_store_dims, gradient_list))
+        flat = torch.flatten(param_grad)
+        ## the accumulator starts as None since torch.cat cannot take a plain list
+        output = flat if output is None else torch.cat([output, flat])
+        return output, grad_dim
 
     def restore_dims(grad_task, chunk_dims):
         ## chunk_dims is a list of tensor shapes
@@ -78,23 +77,32 @@ def project_gradients(grad_task):
         Returns: Component subtracted gradient
         """
-        grad_task, grad_dim = flatten_and_store_dims(grad_task)
-        for k in range(num_tasks): # TODO(speedup): convert to map since they are faster
-            conflict_gradient_candidate = gradient_list[k]
+        def get_projected_gradient_sum(k, conflict_gradient_candidate):
+            ## returns the component of grad_task along this candidate (0 if no conflict)
             # no need to store dims of candidate since we are not changing it in the array
-            conflict_gradient_candidate, _ = flatten_and_store_dims(grad_task)
+            conflict_gradient_candidate, _ = reduce(flatten_and_store_dims, conflict_gradient_candidate, (None, None))
             inner_product = torch.dot(torch.flatten(grad_task), torch.flatten(conflict_gradient_candidate))
-            # TODO(speedup): put conflict check condition here so that we aren't calculating norms for non-conflicting gradients
+            if inner_product >= 0.:
+                ## no conflict, don't do heavy operations
+                return 0
+
             proj_direction = inner_product / torch.norm(conflict_gradient_candidate)**2
-
-            ## sanity check to see if there's any conflicting gradients
-            # if proj_direction < 0.:
-            #     print('conflict')
-            # TODO(speedup): This is a cumulative subtraction, move to threaded in-memory map-reduce
-            grad_task = grad_task - min(proj_direction, 0.) * conflict_gradient_candidate
+
+            ## proj_direction is negative here, so this is the component to subtract
+            return proj_direction * conflict_gradient_candidate
+
+        grad_task, grad_dim = reduce(flatten_and_store_dims, grad_task, (None, []))
+        # Note: these are thread pools; nesting process Pools here would fail
+        # because daemonic workers cannot spawn children
+        with Pool(num_workers) as pool:
+            results = pool.starmap(get_projected_gradient_sum, zip(range(num_tasks), gradient_list))
+
+        grad_task = grad_task - reduce(lambda a, b: a + b, results)
         # get back grad_task
         grad_task = restore_dims(grad_task, grad_dim)
+
         return grad_task
 
-    flattened_grad_task = list(map(project_gradients, gradient_list))
+    with Pool(num_workers) as flatmappool:
+        flattened_grad_task = list(flatmappool.map(project_gradients, gradient_list))
 
     yield flattened_grad_task
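
Usage sketch (addendum, not part of the diff): a minimal driver for the updated pc_grad_update, assuming each entry of gradient_list holds one task's per-parameter gradient tensors (as the docstring implies) and that restore_dims in the unchanged part of pcgrad.py inverts the flattening. The two-task model and losses below are hypothetical.

    import torch
    import torch.nn as nn
    from pcgrad import pc_grad_update

    model = nn.Linear(4, 2)              # hypothetical shared model
    params = list(model.parameters())
    x = torch.randn(8, 4)
    losses = [model(x).pow(2).mean(),    # hypothetical per-task losses
              model(x).abs().mean()]

    # one entry per task: that task's gradients, one tensor per parameter
    gradient_list = []
    for loss in losses:
        # retain the graph in case the task losses share a forward pass
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        gradient_list.append([g.detach() for g in grads])

    # pc_grad_update is a generator; take the single value it yields
    projected = next(pc_grad_update(gradient_list, num_workers=2))

    # pc_grad_update shuffles the task order in place; summing the
    # surgery-adjusted per-task gradients makes the order irrelevant
    for p, *task_grads in zip(params, *projected):
        p.grad = sum(task_grads)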