-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathplot.py
More file actions
134 lines (117 loc) · 5.28 KB
/
plot.py
File metadata and controls
134 lines (117 loc) · 5.28 KB
1
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import argparse
import json
from matplotlib.patches import Rectangle

plt.rcParams.update({'font.size': 15})

# One line style per curve in plot_figure; cycled when there are more records.
linestyles = [':', '--', '-.', '-', '-', '--', '-.', (0, (1, 1))]
# marker = ['*', 'o', ',', 'h', ',', 'h', '*', 'o']


def _load_record(path):
    """Return the JSON object stored in the file at *path*."""
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale.
    with open(path, "r", encoding="utf-8") as read_file:
        return json.load(read_file)


def _load_series(path, key):
    """Return entry *key* of the JSON record at *path* as a NumPy array."""
    return np.array(_load_record(path)[key])


def plot_figure(args):
    """Plot ``1 - value / 100`` for every record on one shared axis and show it.

    Each entry of ``args.record`` is ``[json_path, key, label, color, ...]``.
    Series are plotted exactly as stored -- no offset or noise is added to any
    label. The maximum of each raw series is printed next to its label.

    NOTE(review): the ``/ 100`` assumes the stored series is an accuracy in
    percent, so the plotted curve is an error rate -- confirm with the trainer.
    """
    plt.figure(figsize=(15, 6))
    current_axis = plt.gca()
    rect_x, rect_y = 0.5, 0.5
    # NOTE(review): facecolor='none' and no edgecolor is given, so this patch is
    # most likely invisible; it still enters the axes' data limits, so
    # autoscaling always includes [0.4, 0.5] x [0.4, 0.5]. Kept so the axis
    # ranges of the figure do not change -- remove if that is unintended.
    current_axis.add_patch(Rectangle((rect_x - 0.1, rect_y - 0.1), 0.1, 0.1,
                                     alpha=1, facecolor='none'))
    for i, item in enumerate(args.record):
        path, key, label, color = item[:4]
        series = _load_series(path, key)
        plt.plot(list(range(len(series))), 1 - series / 100,
                 linestyle=linestyles[i % len(linestyles)],
                 label=label, color=color)
        print(max(series), label)
    plt.legend()
    plt.tight_layout()
    plt.show()


def plot_variance(args):
    """Plot the element-wise square of each record in its own panel and show it.

    Panels are laid out on a fixed 2x4 grid, so at most 8 records are
    supported. Only the first column keeps y tick labels.

    NOTE(review): squaring suggests the stored key holds a standard deviation
    and the panel shows a variance -- confirm against the record format.
    """
    plt.figure(figsize=(12, 5))
    for i, item in enumerate(args.record):
        path, key, label, color = item[:4]
        squared = _load_series(path, key) ** 2
        plt.subplot(2, 4, i + 1)
        plt.plot(squared, label=label, color=color)
        plt.ylim(-1, 200)
        if i % 4 != 0:
            plt.yticks([])
        plt.legend()
    plt.tight_layout()
    plt.show()


def plot_value(args):
    """Plot each record unmodified in its own panel of a 2x4 grid and show it.

    At most 8 records are supported by the fixed grid.
    """
    plt.figure(figsize=(12, 5))
    for i, item in enumerate(args.record):
        path, key, label, color = item[:4]
        series = _load_series(path, key)
        plt.subplot(2, 4, i + 1)
        plt.plot(series, label=label, color=color)
        plt.ylim(-3.4, 60)
        plt.legend(loc='upper right')
    plt.tight_layout()
    plt.show()


def plot_loss(args):
    """Plot each record in its own panel of a 1x8 row and save it to ``args.name``.

    Only every ``args.sample``-th point is drawn (x positions are the original
    indices). Only the first panel keeps y tick labels. The figure is closed
    after saving so repeated calls (e.g. in a sweep loop) do not accumulate
    open figures.
    """
    plt.figure(figsize=(30, 2))
    for i, item in enumerate(args.record):
        path, key, label, color = item[:4]
        series = _load_series(path, key)
        plt.subplot(1, 8, i + 1)
        index = list(range(0, len(series), args.sample))
        plt.plot(index, series[index], label=label, color=color)
        plt.ylim(0, 4)
        if i != 0:
            plt.yticks([])
        plt.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig(args.name)
    plt.close()


def print_acc(args):
    """Print each record path with the maximum of its ``'test acc avg'`` series.

    Records whose JSON has no ``'test acc avg'`` key are skipped.
    """
    for item in args.record:
        data = _load_record(item[0])
        if 'test acc avg' in data:
            print(item[0], np.array(data['test acc avg']).max())


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Draw figures')
    # Each record is [json_path, key, label, color]. type=json.loads lets a
    # JSON list of lists be given on the command line (argparse passes a str).
    parser.add_argument('--record', type=json.loads, default=[
        ['record/v2/CIFAR10/CIFAR10_Adam_64_resnet18_0.1.json', 'test acc avg', 'Adam', 'darksalmon'],
        ['record/v2/CIFAR10/CIFAR10_KO_64_resnet18_0.1.json', 'test acc avg', 'FGD-K', 'yellowgreen'],
        # NOTE(review): this 'FGD-W' entry reads the Adam record (the MNIST
        # configuration below uses a WTSGD file for FGD-W) -- confirm the path.
        ['record/v2/CIFAR10/CIFAR10_Adam_64_resnet18_0.1.json', 'test acc avg', 'FGD-W', 'lightpink'],
        ['record/v2/CIFAR10/CIFAR10_SGD_64_resnet18_0.1.json', 'test acc avg', 'SGD', 'aquamarine'],
        ['record/v2/CIFAR10/CIFAR10_ARMAGD_64_resnet18_0.1_[0.0, 0.9].json', 'test acc avg', 'FGD-AR(1)', 'darkgray'],
        ['record/v2/CIFAR10/CIFAR10_ARMAGD_64_resnet18_0.1_[0.1, 0.8].json', 'test acc avg', 'FGD-AR(2)', 'crimson'],
        ['record/v2/CIFAR10/CIFAR10_MASGD_64_resnet18_0.1_[0.0, 0.9].json', 'test acc avg', 'FGD-MA(1)', 'goldenrod'],
        ['record/v2/CIFAR10/CIFAR10_MASGD_64_resnet18_0.1_[0.1, 0.8].json', 'test acc avg', 'FGD-MA(2)', 'mediumpurple'],
    ])
    # Used by plot_loss; previously only set by the commented-out sweep below.
    parser.add_argument('--name', default='loss.pdf',
                        help='output file written by plot_loss')
    parser.add_argument('--sample', type=int, default=1,
                        help='plot every N-th point in plot_loss')
    args = parser.parse_args()

    # sample_dict = {4: 500, 16: 100, 64: 10}
    # plot_value(args)
    # plot_variance(args)
    # for b in [4, 16, 64]:
    #     for lr in [0.1, 0.01, 0.001]:
    #         args.record = [
    #             # ['record/v2/MNIST/MNIST_Adam_{}_MLP_{}.json'.format(b,lr),'train loss', 'Adam' ,'darksalmon'],
    #             ['record/v2/MNIST/MNIST_KO_{}_MLP_{}.json'.format(b, lr), 'train loss', 'FGD-K', 'yellowgreen'],
    #             # ['record/v2/MNIST/MNIST_WTSGD_{}_MLP_{}.json'.format(b,lr),'train loss', 'FGD-W','lightpink'],
    #             # ['record/v2/MNIST/MNIST_SGD_{}_MLP_{}.json'.format(b,lr),'train loss', 'SGD','aquamarine'],
    #             # ['record/v2/MNIST/MNIST_SGD-M_{}_MLP_{}_0.9.json'.format(b,lr),'train loss', 'FGD-AR(1)','darkgray'],
    #             # ['record/v2/MNIST/MNIST_ARMAGD_{}_MLP_{}_[0.1, 0.8].json'.format(b,lr),'train loss', 'FGD-AR(2)','crimson'],
    #             # ['record/v2/MNIST/MNIST_MASGD_{}_MLP_{}_[0.0, 0.9].json'.format(b,lr),'train loss', 'FGD-MA(1)','goldenrod'],
    #             # ['record/v2/MNIST/MNIST_MASGD_{}_MLP_{}_[0.1, 0.8].json'.format(b,lr),'train loss', 'FGD-MA(2)','mediumpurple']
    #         ]
    #         args.name = 'record/v2/MNIST/MNIST_{}_{}_trainloss.pdf'.format(lr, b)
    #         args.sample = sample_dict[b]
    #         plot_loss(args)
    # print_acc(args)
    plot_figure(args)