@@ -118,32 +118,32 @@ def block_diagonal(matrices):
118
118
return blocked
119
119
120
120
121
- def draw_heatmap (data , save_path , xlabels = None , ylabels = None ):
122
- # data = np.clip(data, -0.05, 0.05)
123
- cmap = cm .get_cmap ("rainbow" , 1000 )
124
- figure = plt .figure (facecolor = "w" )
125
- ax = figure .add_subplot (1 , 1 , 1 , position = [0.1 , 0.15 , 0.8 , 0.8 ])
126
- if xlabels is not None :
127
- ax .set_xticks (range (len (xlabels )))
128
- ax .set_xticklabels (xlabels )
129
- if ylabels is not None :
130
- ax .set_yticks (range (len (ylabels )))
131
- ax .set_yticklabels (ylabels )
132
-
133
- vmax = data [0 ][0 ]
134
- vmin = data [0 ][0 ]
135
- for i in data :
136
- for j in i :
137
- if j > vmax :
138
- vmax = j
139
- if j < vmin :
140
- vmin = j
141
- map = ax .imshow (
142
- data , interpolation = "nearest" , cmap = cmap , aspect = "auto" , vmin = vmin , vmax = vmax
143
- )
144
- plt .colorbar (mappable = map , cax = None , ax = None , shrink = 0.5 )
145
- plt .savefig (save_path )
146
- plt .close ()
121
+ # def draw_heatmap(data, save_path, xlabels=None, ylabels=None):
122
+ # # data = np.clip(data, -0.05, 0.05)
123
+ # cmap = cm.get_cmap("rainbow", 1000)
124
+ # figure = plt.figure(facecolor="w")
125
+ # ax = figure.add_subplot(1, 1, 1, position=[0.1, 0.15, 0.8, 0.8])
126
+ # if xlabels is not None:
127
+ # ax.set_xticks(range(len(xlabels)))
128
+ # ax.set_xticklabels(xlabels)
129
+ # if ylabels is not None:
130
+ # ax.set_yticks(range(len(ylabels)))
131
+ # ax.set_yticklabels(ylabels)
132
+
133
+ # vmax = data[0][0]
134
+ # vmin = data[0][0]
135
+ # for i in data:
136
+ # for j in i:
137
+ # if j > vmax:
138
+ # vmax = j
139
+ # if j < vmin:
140
+ # vmin = j
141
+ # map = ax.imshow(
142
+ # data, interpolation="nearest", cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax
143
+ # )
144
+ # plt.colorbar(mappable=map, cax=None, ax=None, shrink=0.5)
145
+ # plt.savefig(save_path)
146
+ # plt.close()
147
147
148
148
149
149
def shape_mask (size , shape ):
@@ -452,39 +452,39 @@ def _draw_real_pred_pairs(real, pred, area_size: int):
452
452
ax .set_aspect (1 )
453
453
454
454
455
- # def draw_heatmap(weights):
456
- # # weights should a 4-D tensor: [M, N, H, W]
457
- # nrow, ncol = weights.shape[0], weights.shape[1]
458
- # fig = plt.figure(figsize=(ncol, nrow))
459
-
460
- # for i in range(nrow):
461
- # for j in range(ncol):
462
- # plt.subplot(nrow, ncol, i * ncol + j + 1)
463
- # weight = weights[i, j]
464
- # vmin, vmax = weight.min() - 0.01, weight.max()
465
-
466
- # cmap = cm.get_cmap("rainbow", 1000)
467
- # cmap.set_under("w")
468
-
469
- # plt.imshow(
470
- # weight,
471
- # interpolation="nearest",
472
- # cmap=cmap,
473
- # aspect="auto",
474
- # vmin=vmin,
475
- # vmax=vmax,
476
- # )
477
- # plt.axis("off")
478
-
479
- # fig.canvas.draw()
480
- # image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
481
- # image_from_plot = image_from_plot.reshape(
482
- # fig.canvas.get_width_height()[::-1] + (3,)
483
- # )
484
- # # plt.show()
485
- # plt.close(fig)
455
+ def draw_heatmap (weights ):
456
+ # weights should a 4-D tensor: [M, N, H, W]
457
+ nrow , ncol = weights .shape [0 ], weights .shape [1 ]
458
+ fig = plt .figure (figsize = (ncol , nrow ))
486
459
487
- # return np.expand_dims(image_from_plot, axis=0)
460
+ for i in range (nrow ):
461
+ for j in range (ncol ):
462
+ plt .subplot (nrow , ncol , i * ncol + j + 1 )
463
+ weight = weights [i , j ]
464
+ vmin , vmax = weight .min () - 0.01 , weight .max ()
465
+
466
+ cmap = cm .get_cmap ("rainbow" , 1000 )
467
+ cmap .set_under ("w" )
468
+
469
+ plt .imshow (
470
+ weight ,
471
+ interpolation = "nearest" ,
472
+ cmap = cmap ,
473
+ aspect = "auto" ,
474
+ vmin = vmin ,
475
+ vmax = vmax ,
476
+ )
477
+ plt .axis ("off" )
478
+
479
+ fig .canvas .draw ()
480
+ image_from_plot = np .frombuffer (fig .canvas .tostring_rgb (), dtype = np .uint8 )
481
+ image_from_plot = image_from_plot .reshape (
482
+ fig .canvas .get_width_height ()[::- 1 ] + (3 ,)
483
+ )
484
+ # plt.show()
485
+ plt .close (fig )
486
+
487
+ return np .expand_dims (image_from_plot , axis = 0 )
488
488
489
489
490
490
def average_appended_metrics (metrics ):
0 commit comments