Some helper functions for PyTorch

'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of a dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
    - format_time: format a duration in seconds as a short readable string.
'''
import math
import os
import shutil
import sys
import time

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data


 15 def get_mean_and_std(dataset):
 16     '''Compute the mean and std value of dataset.'''
 17     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
 18     mean = torch.zeros(3)
 19     std = torch.zeros(3)
 20     print('==> Computing mean and std..')
 21     for inputs, targets in dataloader:
 22         for i in range(3):
 23             mean[i] += inputs[:,i,:,:].mean()
 24             std[i] += inputs[:,i,:,:].std()
 25     mean.div_(len(dataset))
 26     std.div_(len(dataset))
 27     return mean, std
 28  
 29 def init_params(net):
 30     '''Init layer parameters.'''
 31     for m in net.modules():
 32         if isinstance(m, nn.Conv2d):
 33             init.kaiming_normal(m.weight, mode='fan_out')
 34             if m.bias:
 35                 init.constant(m.bias, 0)
 36         elif isinstance(m, nn.BatchNorm2d):
 37             init.constant(m.weight, 1)
 38             init.constant(m.bias, 0)
 39         elif isinstance(m, nn.Linear):
 40             init.normal(m.weight, std=1e-3)
 41             if m.bias:
 42                 init.constant(m.bias, 0)
 43  
 44  
 45 _, term_width = os.popen('stty size', 'r').read().split()
 46 term_width = int(term_width)
 47  
 48 TOTAL_BAR_LENGTH = 65.
 49 last_time = time.time()
 50 begin_time = last_time
 51 def progress_bar(current, total, msg=None):
 52     global last_time, begin_time
 53     if current == 0:
 54         begin_time = time.time()  # Reset for new bar.
 55  
 56     cur_len = int(TOTAL_BAR_LENGTH*current/total)
 57     rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
 58  
 59     sys.stdout.write(' [')
 60     for i in range(cur_len):
 61         sys.stdout.write('=')
 62     sys.stdout.write('>')
 63     for i in range(rest_len):
 64         sys.stdout.write('.')
 65     sys.stdout.write(']')
 66  
 67     cur_time = time.time()
 68     step_time = cur_time - last_time
 69     last_time = cur_time
 70     tot_time = cur_time - begin_time
 71  
 72     L = []
 73     L.append('  Step: %s' % format_time(step_time))
 74     L.append(' | Tot: %s' % format_time(tot_time))
 75     if msg:
 76         L.append(' | ' + msg)
 77  
 78     msg = ''.join(L)
 79     sys.stdout.write(msg)
 80     for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
 81         sys.stdout.write(' ')
 82  
 83     # Go back to the center of the bar.
 84     for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
 85         sys.stdout.write('')
 86     sys.stdout.write(' %d/%d ' % (current+1, total))
 87  
 88     if current < total-1:
 89         sys.stdout.write('
')
 90     else:
 91         sys.stdout.write('
')
 92     sys.stdout.flush()
 93  
 94 def format_time(seconds):
 95     days = int(seconds / 3600/24)
 96     seconds = seconds - days*3600*24
 97     hours = int(seconds / 3600)
 98     seconds = seconds - hours*3600
 99     minutes = int(seconds / 60)
100     seconds = seconds - minutes*60
101     secondsf = int(seconds)
102     seconds = seconds - secondsf
103     millis = int(seconds*1000)
104  
105     f = ''
106     i = 1
107     if days > 0:
108         f += str(days) + 'D'
109         i += 1
110     if hours > 0 and i <= 2:
111         f += str(hours) + 'h'
112         i += 1
113     if minutes > 0 and i <= 2:
114         f += str(minutes) + 'm'
115         i += 1
116     if secondsf > 0 and i <= 2:
117         f += str(secondsf) + 's'
118         i += 1
119     if millis > 0 and i <= 2:
120         f += str(millis) + 'ms'
121         i += 1
122     if f == '':
123         f = '0ms'
124     return f
Original article: https://www.cnblogs.com/jiangkejie/p/11201133.html