@@ -40,7 +40,6 @@ def train(config, model, train_data, val_data):
     print("Warmup steps: {}".format(warmup_steps))

     num_steps = 0
-    best_f1 = 0
     model.train()

     for epoch in range(int(config.train.epochs)):
@@ -50,7 +49,8 @@ def train(config, model, train_data, val_data):
             outputs = model(batch)
             min_loss = model.min_loss(batch)
             max_loss = model.max_loss(batch)
-            (min_loss - max_loss).backward()
+            min_loss.backward(retain_graph=True)
+            max_loss.backward()

             torch.nn.utils.clip_grad_norm_(
                 model.parameters(), config.train.max_grad_norm
@@ -64,32 +64,9 @@ def train(config, model, train_data, val_data):
             if not config.debug:
                 wandb.log({"loss": loss.item()}, step=num_steps)

-        output = validate(config, model, val_data)
         if not config.debug:
             wandb.log(output, step=num_steps)
-
-        if output["validation_f1"] > best_f1:
-            print(f"Best validation F1! Saving to {config.train.pt}")
-            torch.save(model.state_dict(), config.train.pt)
-
-        best_f1 = max(best_f1, output["validation_f1"])
-
-
-def validate(config, model, data):
-
-    model.eval()
-    with torch.no_grad():
-        outputs = []
-        for batch in tqdm(data):
-            outputs.append(model(batch))
-        outputs = torch.cat(outputs)
-        outputs = outputs.cpu().numpy()
-        outputs = np.argmax(outputs, axis=1)
-        outputs = outputs.flatten()
-        labels = data.dataset.labels.cpu().numpy()
-        labels = labels.flatten()
-        f1 = f1_score(labels, outputs, average="macro")
-        return {"validation_f1": f1}
+        torch.save(model.state_dict(), config.train.pt)


@hydra.main(config_path="./conf", config_name="config")
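The only behavioral change in the middle hunk is how the two losses reach `backward()`. As a reference for what that pattern does in PyTorch, here is a minimal standalone sketch; the tensors `x`, `y` and the toy loss expressions are hypothetical stand-ins, not the PR's `model.min_loss`/`model.max_loss`. Two successive `backward()` calls accumulate gradients into `.grad`, so the pair behaves like a single `(min_loss + max_loss).backward()`, and `retain_graph=True` on the first call keeps the shared graph alive for the second call.

```python
import torch

# Toy sketch (hypothetical tensors, not the PR's model) of the new pattern.
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = 3 * x                      # shared intermediate node
min_loss = (y ** 2).sum()      # stand-in for model.min_loss(batch)
max_loss = y.sum()             # stand-in for model.max_loss(batch)

min_loss.backward(retain_graph=True)   # d(min_loss)/dx = 6*y -> [18., 36.]
max_loss.backward()                    # adds d(max_loss)/dx = 3 -> [21., 39.]

print(x.grad)  # tensor([21., 39.]), same as (min_loss + max_loss).backward()
```

Note that, compared with the removed `(min_loss - max_loss).backward()`, the accumulated result carries the gradient of `max_loss` with a plus sign rather than a minus sign; without `retain_graph=True`, the second `backward()` would raise a RuntimeError because the shared graph is freed after the first pass.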