@@ -20,9 +20,10 @@ def __init__(self, lr_dataloader, lr_finder=True, num_classes=1000):
         self.lr_dataloader = lr_dataloader
         self.save_hyperparameters()
         self.model = models.resnet50(pretrained=False, num_classes=num_classes)
+        # self.model.gradient_checkpointing_enable()  # Enable gradient checkpointing
         # self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
         self.criterion = torch.nn.CrossEntropyLoss()
-        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        # self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
         self.epoch_start_time = None

     def forward(self, x):
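# Note on the commented-out gradient_checkpointing_enable() above: that method
# comes from Hugging Face transformers models; torchvision's ResNet does not
# provide it, which is presumably why the line stays disabled. A minimal sketch
# of how checkpointing could be wired in manually with torch.utils.checkpoint
# (function name and segment count are illustrative, not part of the diff):
import torch
from torch.utils.checkpoint import checkpoint_sequential
from torchvision import models

model = models.resnet50(pretrained=False, num_classes=1000)

def forward_with_checkpointing(x):
    # Stem runs normally; the four residual stages are checkpointed,
    # trading extra forward compute for lower activation memory.
    x = model.maxpool(model.relu(model.bn1(model.conv1(x))))
    stages = torch.nn.Sequential(model.layer1, model.layer2, model.layer3, model.layer4)
    x = checkpoint_sequential(stages, segments=4, input=x, use_reentrant=False)
    x = torch.flatten(model.avgpool(x), 1)
    return model.fc(x)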
@@ -35,32 +36,34 @@ def training_step(self, batch, batch_idx):
         data, target = batch
         output = self(data)
         loss = self.criterion(output, target)
-        prediction = output.argmax(dim=1)
-        acc = self.accuracy(prediction, target)
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        # prediction = output.argmax(dim=1)
+        # acc = self.accuracy(prediction, target)
+        acc = (output.argmax(dim=1) == target).float().mean()
+        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        self.log("train_acc", acc * 100, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
         return loss

     def on_train_epoch_end(self):
         epoch_duration = time.time() - self.epoch_start_time
-        self.log('train_epoch_time', epoch_duration, on_epoch=True, prog_bar=True, logger=True)
+        self.log('train_epoch_time', epoch_duration, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

     def validation_step(self, batch, batch_idx):
         # Run validation only every 10 epochs and on the last epoch
         # if (self.current_epoch + 1) % 10 == 0 or (self.current_epoch + 1) == self.trainer.max_epochs:
         data, target = batch
         output = self(data)
         loss = self.criterion(output, target).item()
-        prediction = output.argmax(dim=1)
-        acc = self.accuracy(prediction, target)
-        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        # prediction = output.argmax(dim=1)
+        # acc = self.accuracy(prediction, target)
+        acc = (output.argmax(dim=1) == target).float().mean()
+        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        self.log("val_acc", acc * 100, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
         return loss

     def on_validation_epoch_end(self):
         if self.epoch_start_time:
             epoch_duration = time.time() - self.epoch_start_time
-            self.log('val_epoch_time', epoch_duration, on_epoch=True, prog_bar=True, logger=True)
+            self.log('val_epoch_time', epoch_duration, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

     def configure_optimizers(self):
         if self.with_lr_finder:
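# The hunk above swaps the torchmetrics Accuracy metric for a manual per-batch
# mean, and adds sync_dist=True so logged values are averaged across DDP ranks.
# A minimal side-by-side sketch of the two accuracy styles (batch size and
# class count are illustrative); note that averaging per-batch means through
# on_epoch=True weights a smaller final batch the same as full batches, while
# a torchmetrics metric accumulates exact counts across the epoch:
import torch
from torchmetrics import Accuracy

logits = torch.randn(8, 1000)          # [batch, num_classes]
target = torch.randint(0, 1000, (8,))  # [batch]

acc_manual = (logits.argmax(dim=1) == target).float().mean()  # as in the diff

metric = Accuracy(task="multiclass", num_classes=1000)
acc_metric = metric(logits.argmax(dim=1), target)             # exact-count variant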
@@ -78,7 +81,8 @@ def configure_optimizers(self):
             scheduler = OneCycleLR(
                 optimizer,
                 max_lr=self.max_lr,
-                steps_per_epoch=len(self.lr_dataloader),
+                # steps_per_epoch=len(self.lr_dataloader),
+                steps_per_epoch=len(self.train_dataloader()),
                 # self.trainer.estimated_stepping_batches // self.trainer.max_epochs,
                 epochs=self.trainer.max_epochs,
                 # pct_start=0.3,
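# OneCycleLR sizes its schedule from steps_per_epoch * epochs, so these must
# match the optimizer steps actually taken; under DDP sharding or gradient
# accumulation, len(dataloader) can drift from that count. The commented-out
# estimated_stepping_batches line above hints at Lightning's built-in estimate.
# A sketch of that alternative wiring (the SGD optimizer and its lr are
# illustrative; total_steps replaces the steps_per_epoch/epochs pair):
def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.parameters(), lr=0.05)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=self.max_lr,
        total_steps=self.trainer.estimated_stepping_batches,
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }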
@@ -96,9 +100,15 @@ def configure_optimizers(self):
             },
         }

+    def train_dataloader(self):
+        if not self.trainer.train_dataloader:
+            self.trainer.fit_loop.setup_data()
+
+        return self.trainer.train_dataloader
+
     def find_lr(self, optimizer):
         lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
         lr_finder.range_test(self.lr_dataloader, end_lr=10, num_iter=500)
         self.max_lr = lr_finder.plot()[-1]  # Plot the loss vs learning rate
         print(f"Suggested learning rate: {self.max_lr}")
-        lr_finder.reset()
+        lr_finder.reset()
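# The new train_dataloader override pulls the loader back off the trainer
# (forcing fit_loop.setup_data() if it has not run yet) so len() is available
# inside configure_optimizers. find_lr itself relies on the torch-lr-finder
# package; a sketch of that flow standalone (model/optimizer/loader names are
# illustrative, and plot()[-1] assumes suggest_lr is enabled so plot returns
# the (ax, suggested_lr) pair):
from torch_lr_finder import LRFinder

lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
lr_finder.range_test(train_loader, end_lr=10, num_iter=500)
ax, suggested_lr = lr_finder.plot()  # loss-vs-LR curve plus suggested LR
lr_finder.reset()                    # restore model and optimizer state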