Skip to content

Commit 6c76277

Browse files
authored
Update ResnetNetwork.py
1 parent 8d53055 commit 6c76277

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

ResnetNetwork.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ def __init__(self, lr_dataloader, lr_finder=True, num_classes=1000):
2020
self.lr_dataloader = lr_dataloader
2121
self.save_hyperparameters()
2222
self.model = models.resnet50(pretrained=False, num_classes=num_classes)
23+
# self.model.gradient_checkpointing_enable() # Enable gradient checkpointing
2324
# self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
2425
self.criterion = torch.nn.CrossEntropyLoss()
25-
self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
26+
# self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
2627
self.epoch_start_time = None
2728

2829
def forward(self, x):
def training_step(self, batch, batch_idx):
    """Run one training step: forward pass, loss, accuracy, and logging.

    Args:
        batch: tuple ``(data, target)`` from the training dataloader
            (presumably images and integer class labels — confirm against
            the datamodule).
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        The cross-entropy loss tensor used for backpropagation.
    """
    data, target = batch
    output = self(data)
    loss = self.criterion(output, target)
    # Plain top-1 accuracy computed inline; avoids holding a torchmetrics
    # Accuracy state object (the commit deliberately dropped it).
    acc = (output.argmax(dim=1) == target).float().mean()
    # sync_dist=True reduces the logged values across processes in DDP runs.
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
    # Logged as a percentage (acc * 100) rather than a fraction.
    self.log("train_acc", acc * 100, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
    return loss
4345

4446
def on_train_epoch_end(self):
    """Log the wall-clock duration of the training epoch that just ended.

    Guards against ``self.epoch_start_time`` being unset/None (it is
    initialised to None in ``__init__``), mirroring the identical check in
    ``on_validation_epoch_end`` — the unguarded subtraction would raise a
    TypeError if the start-time hook never fired.
    """
    if self.epoch_start_time:
        epoch_duration = time.time() - self.epoch_start_time
        self.log('train_epoch_time', epoch_duration, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
4749

4850
def validation_step(self, batch, batch_idx):
    """Run one validation step: forward pass, loss, accuracy, and logging.

    Args:
        batch: tuple ``(data, target)`` from the validation dataloader.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        The validation loss as a Python float.
    """
    data, target = batch
    output = self(data)
    # .item() detaches the loss to a plain float — acceptable here since
    # validation loss is only logged, never backpropagated.
    loss = self.criterion(output, target).item()
    # Plain top-1 accuracy computed inline (torchmetrics Accuracy dropped).
    acc = (output.argmax(dim=1) == target).float().mean()
    # Epoch-level only (on_step=False); sync_dist=True aggregates across
    # processes in DDP runs.
    self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
    # Logged as a percentage (acc * 100) rather than a fraction.
    self.log("val_acc", acc * 100, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
    return loss
5962

6063
def on_validation_epoch_end(self):
    """Log how long the epoch containing this validation pass has taken.

    Does nothing when ``self.epoch_start_time`` is unset (falsy), so the
    hook is safe to fire before any start time was recorded.
    """
    if not self.epoch_start_time:
        return
    elapsed = time.time() - self.epoch_start_time
    self.log('val_epoch_time', elapsed, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
6467

6568
def configure_optimizers(self):
6669
if self.with_lr_finder:
@@ -78,7 +81,8 @@ def configure_optimizers(self):
7881
scheduler = OneCycleLR(
7982
optimizer,
8083
max_lr=self.max_lr,
81-
steps_per_epoch=len(self.lr_dataloader),
84+
# steps_per_epoch=len(self.lr_dataloader),
85+
steps_per_epoch=len(self.train_dataloader()),
8286
# self.trainer.estimated_stepping_batches // self.trainer.max_epochs,
8387
epochs=self.trainer.max_epochs,
8488
# pct_start=0.3,
@@ -96,9 +100,15 @@ def configure_optimizers(self):
96100
},
97101
}
98102

103+
def train_dataloader(self):
    """Return the trainer's training dataloader, materialising it on demand.

    If the trainer has not built its dataloaders yet, ask the fit loop to
    set up its data first (Lightning-internal API — behaviour depends on
    the installed Lightning version; verify on upgrade).
    """
    loader = self.trainer.train_dataloader
    if not loader:
        self.trainer.fit_loop.setup_data()
        loader = self.trainer.train_dataloader
    return loader
108+
99109
def find_lr(self, optimizer):
    """Run an LR range test and store the suggested max LR on ``self.max_lr``.

    Uses torch-lr-finder's range test over ``self.lr_dataloader``, then
    resets the model/optimizer state the test mutated.
    """
    finder = LRFinder(self, optimizer, criterion=self.criterion)
    finder.range_test(self.lr_dataloader, end_lr=10, num_iter=500)
    # Plot the loss vs learning rate; the last element of plot()'s return
    # is taken as the suggested LR (NOTE(review): plot() only returns a
    # suggestion tuple when asked to suggest — confirm against the
    # installed torch-lr-finder version).
    self.max_lr = finder.plot()[-1]
    print(f"Suggested learning rate: {self.max_lr}")
    finder.reset()

0 commit comments

Comments
 (0)