Commit 01e0080

committed Nov 10, 2021
train script; update configs
1 parent 9bf75b9 commit 01e0080

File tree

6 files changed: +215 -15 lines changed

 

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -127,3 +127,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+.vscode

README.md

Lines changed: 36 additions & 1 deletion

@@ -1 +1,36 @@
-# anomaly_transformer_pytorch
+# Anomaly Transformer in PyTorch
+
+This is an implementation of [Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy](https://arxiv.org/abs/2110.02642). The paper is currently [under review](https://openreview.net/forum?id=LzQQ89U1qm_) and some details of its attention mechanism still need clarification; this repo will be updated as more information becomes available.
+
+## Usage
+
+### Requirements
+
+Install dependencies into a virtualenv:
+
+```bash
+$ python -m venv env
+$ source env/bin/activate
+(env) $ pip install -r requirements.txt
+```
+
+### Data and Configuration
+
+Custom datasets can be placed in the `data/` dir. Edit `conf/data/default.yaml` to reflect the properties of the data; all other hyperparameters can be set in the hydra configs.
+
+### Train
+
+Once properly configured, a model can be trained via `python train.py`.
+
+## Citations
+
+```bibtex
+@misc{xu2021anomaly,
+      title={Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy},
+      author={Jiehui Xu and Haixu Wu and Jianmin Wang and Mingsheng Long},
+      year={2021},
+      eprint={2110.02642},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG}
+}
+```
File renamed without changes.
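The README's "Data and Configuration" note above assumes a dataset dropped into `data/` and described by `conf/data/default.yaml`, but the commit itself does not show what shape the training loader expects. As a rough sketch only, a windowed time-series `Dataset` such as the one below would be compatible with the plain `DataLoader` constructed in `train.py`; the CSV layout, the class name, and the window length are illustrative assumptions, not part of the repo.

```python
# Hypothetical helper, not part of this commit: wraps a CSV of shape
# (timesteps, channels) into fixed-length windows for a torch DataLoader.
import numpy as np
import torch
from torch.utils.data import Dataset


class WindowedSeries(Dataset):
    def __init__(self, csv_path: str, window: int = 100):
        series = np.loadtxt(csv_path, delimiter=",", dtype=np.float32)
        if series.ndim == 1:  # single-channel series -> add a feature axis
            series = series[:, None]
        self.series = torch.from_numpy(series)
        self.window = window

    def __len__(self):
        return max(len(self.series) - self.window + 1, 0)

    def __getitem__(self, idx):
        # each item is one (window, channels) slice of the raw series
        return self.series[idx : idx + self.window]
```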

conf/config.yaml

Lines changed: 46 additions & 2 deletions

@@ -1,7 +1,51 @@
 defaults:
-  - model: transformers
+  - model: default
+  - data: default
   - evaluate: default
   - train: default
 
+seed: 0
+debug: False
+silent: False
+device: cuda
+
+max_iters: 1000000
+log_interval: 100
+val_interval: 5000
+model_save_pt: 5000
+
+lr: 1e-5
+batch_size: 32
+val_steps: 500
+grad_clip: 100.
+early_stop_patience: 20000
+early_stop_key: "loss/total_edit_val"
+dropout: 0.0
+results_dir: null
+
+eval_only: False
+half: False
+save: False
+
+model:
+  pt: null
+
+data:
+  path: null
+  rephrase: true
+  zsre_nq: true
+  nq_path: ${hydra:runtime.cwd}/data/nq
+  wiki_webtext: true
+  n_edits: 1
+
+eval:
+  verbose: True
+  log_interval: 100
+  final_eval: True
+
 hydra:
-  auto: False
+  run:
+    dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}}
+  sweep:
+    dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f}
+    subdir: ${hydra.job.num}
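For a quick look at the keys this commit adds, the config can be loaded directly with OmegaConf and merged with dot-list overrides. This is only an inspection sketch: plain OmegaConf does not compose the `defaults` list or resolve the `${now:...}` and `${hydra:...}` expressions, which hydra handles at runtime via `train.py`.

```python
# Inspection sketch only: peek at the top-level keys added to conf/config.yaml.
# The defaults list and hydra-specific resolvers are left unresolved here.
from omegaconf import OmegaConf

base = OmegaConf.load("conf/config.yaml")
overrides = OmegaConf.from_dotlist(["lr=1e-4", "batch_size=64", "device=cpu"])
config = OmegaConf.merge(base, overrides)

print(config.lr, config.batch_size, config.device)  # overridden training knobs
print(OmegaConf.to_yaml(config.data))                # the new data block
```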

model.py

Lines changed: 19 additions & 12 deletions

@@ -8,10 +8,11 @@
 
 
 class AnomalyAttention(nn.Module):
-    def __init__(self, seq_dim, channels):
+    def __init__(self, seq_dim, in_channels, out_channels):
         super(AnomalyAttention, self).__init__()
-        self.Q = self.K = self.V = self.sigma = torch.zeros((seq_dim, channels))
-        self.d_model = channels
+        self.W = nn.Linear(in_channels, out_channels, bias=False)
+        self.Q = self.K = self.V = self.sigma = torch.zeros((seq_dim, out_channels))
+        self.d_model = out_channels
         self.n = seq_dim
         self.P = torch.zeros((seq_dim, seq_dim))
         self.S = torch.zeros((seq_dim, seq_dim))
@@ -21,37 +22,43 @@ def forward(self, x):
         self.initialize(x)
         self.P = self.prior_association()
         self.S = self.series_association()
-        print(self.S.shape)
-        # assert self.S.shape == (self.n, self.n)
         Z = self.reconstruction()
 
         return Z
 
     def initialize(self, x):
         # self.d_model = x.shape[-1]
-        self.Q = self.K = self.V = self.sigma = x
+        self.Q = self.K = self.V = self.sigma = self.W(x)
+
 
     def prior_association(self):
-        return torch.ones((self.n, self.n))
+        p = torch.from_numpy(
+            np.abs(
+                np.indices((self.n, self.n))[0] -
+                np.indices((self.n, self.n))[1]
+            )
+        )
+        gaussian = torch.normal(p.float(), self.sigma[:, 0].abs())
+        gaussian /= gaussian.sum(dim=-1).view(-1, 1)
+
+        return gaussian
 
     def series_association(self):
-        print(self.Q.shape)
-        print(self.K.shape)
         return F.softmax((self.Q @ self.K.T) / math.sqrt(self.d_model), dim=0)
 
     def reconstruction(self):
         return self.S @ self.V
 
     def association_discrepancy(self):
-        return F.kl_div(self.P, self.S) + F.kl_div(self.S, self.P) #not going to be correct dimensions
+        return F.kl_div(self.P, self.S) + F.kl_div(self.S, self.P)
 
 
 class AnomalyTransformerBlock(nn.Module):
     def __init__(self, seq_dim, feat_dim):
         super().__init__()
         self.seq_dim, self.feat_dim = seq_dim, feat_dim
 
-        self.attention = AnomalyAttention(self.seq_dim, self.feat_dim)
+        self.attention = AnomalyAttention(self.seq_dim, self.feat_dim, self.feat_dim)
         self.ln1 = nn.LayerNorm(self.feat_dim)
         self.ff = nn.Sequential(
             nn.Linear(self.feat_dim, self.feat_dim),
@@ -94,7 +101,7 @@ def forward(self, x):
 
     def loss(self, x):
         l2_norm = torch.linalg.matrix_norm(self.output - x, ord=2)
-        return l2_norm + (lambda_ * self.assoc_discrepancy)
+        return l2_norm + (self.lambda_ * self.assoc_discrepancy.mean())
 
     def anomaly_score(self, x):
         score = F.softmax(-self.assoc_discrepancy, dim=0)
train.py

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+import logging
+from datetime import datetime
+
+import numpy as np
+import torch
+import wandb
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+
+import hydra
+from omegaconf import DictConfig
+from omegaconf.omegaconf import OmegaConf
+from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
+
+from model import AnomalyTransformer
+
+logger = logging.getLogger(__name__)
+
+
+def train(config, model, train_data, val_data):
+
+    train_dataloader = DataLoader(
+        train_data,
+        batch_size=config.train.batch_size,
+        shuffle=config.train.shuffle,
+        # collate_fn=collate_fn,
+        drop_last=True,
+    )
+    total_steps = int(len(train_dataloader) * config.train.epochs)
+    warmup_steps = max(int(total_steps * config.train.warmup_ratio), 200)
+    optimizer = AdamW(
+        model.parameters(),
+        lr=config.train.lr,
+        eps=config.train.adam_epsilon,
+    )
+    scheduler = get_cosine_schedule_with_warmup(
+        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
+    )
+    print("Total steps: {}".format(total_steps))
+    print("Warmup steps: {}".format(warmup_steps))
+
+    num_steps = 0
+    best_f1 = 0
+    model.train()
+
+    for epoch in range(int(config.train.epochs)):
+        model.zero_grad()
+        for step, batch in enumerate(tqdm(train_dataloader)):
+
+            outputs = model(**inputs)
+            loss = outputs.loss()
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), config.train.max_grad_norm
+            )
+            optimizer.step()
+            scheduler.step()
+            model.zero_grad()
+
+            num_steps += 1
+
+            if not config.debug:
+                wandb.log({"loss": loss.item()}, step=num_steps)
+
+        output = validate(config, model, val_data)
+        if not config.debug:
+            wandb.log(output, step=num_steps)
+
+        if output["validation_f1"] > best_f1:
+            print(f"Best validation F1! Saving to {config.train.pt}")
+            torch.save(model.state_dict(), config.train.pt)
+
+        best_f1 = max(best_f1, output["validation_f1"])
+
+
+def validate(config, model, data):
+    return 0
+
+
+@hydra.main(config_path="./conf", config_name="config")
+def main(config: DictConfig) -> None:
+
+    set_seed(config.train.state.seed)
+
+    logger.info(OmegaConf.to_yaml(config, resolve=True))
+    logger.info(f"Using the model: {config.model.name}")
+
+    train_data, val_data = get_data(config)
+    config.data.num_class = len(set([x["labels"] for x in train_features]))
+    print(f"num_class: {config.data.num_class}")
+
+    if not config.debug:
+        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+        run_name = f"{config.train.wandb.run_name}_{config.model.model}_{config.data.name}_{timestamp}"
+        wandb.init(
+            entity=config.train.wandb_entity,
+            project=config.train.wandb_project,
+            config=dict(config),
+            name=run_name,
+        )
+        if not config.train.pt:
+            config.train.pt = f"{config.train.pt}/{run_name}"
+
+    model = AnomalyTransformer(config)
+    model.to(config.device)
+
+    train(config, model, train_data, val_data)
+
+
+if __name__ == "__main__":
+    main()
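One loose end in `train.py` as committed: `validate()` returns `0`, while the training loop indexes `output["validation_f1"]` and logs the dict to wandb. A placeholder that at least matches that expected shape might look like the sketch below; the reconstruction-error scoring rule, the fixed threshold, and the assumption that `val_data` yields `(window, label)` pairs are all illustrative and not part of the commit.

```python
# Illustrative placeholder only. Assumptions: model(x) returns a reconstruction of x,
# val_data yields (window, label) pairs, and a fixed error threshold is acceptable.
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def validate(config, model, val_data, threshold=0.5):
    model.eval()
    tp = fp = fn = 0
    for x, y in DataLoader(val_data, batch_size=config.train.batch_size):
        x = x.to(config.device)
        x_hat = model(x)                           # assumed reconstruction output
        err = ((x_hat - x) ** 2).mean(dim=(1, 2))  # one score per window
        pred = (err > threshold).long().cpu()
        tp += int(((pred == 1) & (y == 1)).sum())
        fp += int(((pred == 1) & (y == 0)).sum())
        fn += int(((pred == 0) & (y == 1)).sum())
    model.train()
    f1 = 2 * tp / max(2 * tp + fp + fn, 1)         # F1 = 2TP / (2TP + FP + FN)
    return {"validation_f1": f1}
```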

0 commit comments
