Skip to content

Commit f9afe3d

Browse files
committed Mar 24, 2022
update gaussian kernel
1 parent 6d15200 commit f9afe3d

File tree

2 files changed

+9
-27
lines changed

2 files changed

+9
-27
lines changed
 

‎model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,16 @@ def initialize(self, x):
3838
self.V = self.Wv(x)
3939
self.sigma = self.Ws(x)
4040

41+
@staticmethod
def gaussian_kernel(mean, sigma):
    """Gaussian density with scale ``sigma``, evaluated at deviation ``mean``.

    Note: despite its name, ``mean`` is the point of evaluation (the
    deviation from the center of a zero-mean Gaussian), not the
    distribution's mean. Both arguments broadcast element-wise.
    """
    # Exponent of the pdf: -(x / sigma)^2 / 2.
    exponent = -0.5 * (mean / sigma).pow(2)
    # Normalization constant 1 / (sqrt(2*pi) * sigma).
    coeff = 1 / (math.sqrt(2 * torch.pi) * sigma)
    return coeff * torch.exp(exponent)
45+
4146
def prior_association(self):
4247
p = torch.from_numpy(
4348
np.abs(np.indices((self.N, self.N))[0] - np.indices((self.N, self.N))[1])
4449
)
45-
gaussian = torch.normal(p.float(), self.sigma[:, 0].abs())
50+
gaussian = self.gaussian_kernel(p.float(), self.sigma)
4651
gaussian /= gaussian.sum(dim=-1).view(-1, 1)
4752

4853
return gaussian

‎train.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def train(config, model, train_data, val_data):
4040
print("Warmup steps: {}".format(warmup_steps))
4141

4242
num_steps = 0
43-
best_f1 = 0
4443
model.train()
4544

4645
for epoch in range(int(config.train.epochs)):
@@ -50,7 +49,8 @@ def train(config, model, train_data, val_data):
5049
outputs = model(batch)
5150
min_loss = model.min_loss(batch)
5251
max_loss = model.max_loss(batch)
53-
(min_loss - max_loss).backward()
52+
min_loss.backward(retain_graph=True)
53+
max_loss.backward()
5454

5555
torch.nn.utils.clip_grad_norm_(
5656
model.parameters(), config.train.max_grad_norm
@@ -64,32 +64,9 @@ def train(config, model, train_data, val_data):
6464
if not config.debug:
6565
wandb.log({"loss": loss.item()}, step=num_steps)
6666

67-
output = validate(config, model, val_data)
6867
if not config.debug:
6968
wandb.log(output, step=num_steps)
70-
71-
if output["validation_f1"] > best_f1:
72-
print(f"Best validation F1! Saving to {config.train.pt}")
73-
torch.save(model.state_dict(), config.train.pt)
74-
75-
best_f1 = max(best_f1, output["validation_f1"])
76-
77-
78-
def validate(config, model, data):
79-
80-
model.eval()
81-
with torch.no_grad():
82-
outputs = []
83-
for batch in tqdm(data):
84-
outputs.append(model(batch))
85-
outputs = torch.cat(outputs)
86-
outputs = outputs.cpu().numpy()
87-
outputs = np.argmax(outputs, axis=1)
88-
outputs = outputs.flatten()
89-
labels = data.dataset.labels.cpu().numpy()
90-
labels = labels.flatten()
91-
f1 = f1_score(labels, outputs, average="macro")
92-
return {"validation_f1": f1}
69+
torch.save(model.state_dict(), config.train.pt)
9370

9471

9572
@hydra.main(config_path="./conf", config_name="config")

0 commit comments

Comments
 (0)