| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
|
|
| import helpers as h |
| import domains |
| from domains import * |
| import math |
|
|
|
|
# Domains that expose an ``attack`` method, plus the raw tensor types.
POINT_DOMAINS = [
    method for method in h.getMethods(domains) if h.hasMethod(method, "attack")
]
POINT_DOMAINS += [torch.FloatTensor, torch.Tensor, torch.cuda.FloatTensor]

# ``domains.Box`` followed by every entry of POINT_DOMAINS.
SYMETRIC_DOMAINS = [domains.Box, *POINT_DOMAINS]
|
|
def domRes(outDom, target, **args):
    """Combine ``outDom`` with target-derived matrices and return a lower bound.

    Two batched square matrices are built from the dense one-hot encoding of
    ``target`` and applied to ``outDom`` through ``outDom.bmm``; the lower
    bound (``.lb()``) of the difference of the two products is returned with
    the one-hot encoding added on top.

    NOTE(review): assuming ``h.one_hot`` yields a 0/1 indicator per row and
    ``outDom.bmm`` multiplies each output row vector by the matrix, entry
    ``j`` of the result is a lower bound on ``out[target] - out[j]`` for
    ``j != target`` and exactly 1 at the target index -- confirm against
    ``helpers`` and ``domains``.

    Args:
        outDom: object providing ``size()``, ``bmm()``, subtraction between
            ``bmm`` results, and ``lb()`` on that difference.
        target: object whose ``.data.long()`` is passed to ``h.one_hot``
            (presumably a tensor of class indices -- TODO confirm).
        **args: accepted but not used by this function.

    Returns:
        ``diff.lb() + one_hot(target)``.
    """
    onehot = h.one_hot(target.data.long(), outDom.size()[1]).to_dense()
    as_col = onehot.unsqueeze(2)
    as_row = onehot.unsqueeze(1)

    # Per-sample outer product of the one-hot vector with itself.
    outer = as_col.matmul(as_row)
    batch, width = outer.size()[0], outer.size()[1]

    # One-hot column repeated across every column of a square matrix.
    spread = as_col.expand(-1, -1, width)

    # Identity with the outer product removed.
    keep_others = h.eye(width).expand(batch, -1, -1) - outer

    pick_target = spread.bmm(keep_others)

    return (outDom.bmm(pick_target) - outDom.bmm(keep_others)).lb() + onehot
|
|
def isSafeDom(outDom, target, **args):
    """Return 1 when every entry of ``domRes`` along dim 1 is positive, else 0.

    The minimum of ``domRes(outDom, target, **args)`` over dimension 1 is
    compared with zero and converted to a Python integer via ``.item()``.

    NOTE(review): ``.item()`` only works on a single-element tensor, so this
    presumably expects a batch of exactly one sample -- confirm with callers.
    """
    margins = domRes(outDom, target, **args)
    worst = torch.min(margins, 1)[0]
    return (worst > 0.0).long().item()
|
|
|
|
def isSafeBox(target, net, inp, eps, dom):
    """Check whether ``net`` keeps its answer on ``inp`` under ``dom``.

    If ``dom`` has an ``attack`` attribute, ``dom.attack(net, eps, inp,
    target)`` produces an input and the result is whether the network's
    argmax on it equals the argmax of ``target`` (a ``bool``). Otherwise the
    network is run on ``dom.box(inp, eps)`` and the verdict of ``isSafeDom``
    is returned.

    Only the first row of the argmax along dim 1 is used in either branch.

    Args:
        target: scores/labels whose ``argmax(1)`` gives the expected class.
        net: callable applied to the attacked input or to the boxed input.
        inp: the input handed to ``dom.attack`` / ``dom.box``.
        eps: the radius handed to ``dom.attack`` / ``dom.box``.
        dom: domain object offering ``attack`` or ``box``.
    """
    expected = target.argmax(1)[0].unsqueeze(0)

    if not hasattr(dom, "attack"):
        # No attack available: propagate the boxed input and check the bound.
        return isSafeDom(net(dom.box(inp, eps)), expected)

    attacked = dom.attack(net, eps, inp, target)
    predicted = net(attacked).argmax(1)[0].unsqueeze(0)
    return predicted.item() == expected.item()
|
|