| import future |
| import builtins |
| import past |
| import six |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import torch.autograd |
| import components as comp |
| from torch.distributions import multinomial, categorical |
|
|
| import math |
| import numpy as np |
|
|
| try: |
| from . import helpers as h |
| from . import ai |
| from . import scheduling as S |
| except: |
| import helpers as h |
| import ai |
| import scheduling as S |
|
|
|
|
|
|
class WrapDom(object):
    """Adapter that forwards box/boxBetween/line to an inner domain ``self.a``.

    Whatever the inner call returns is passed through ``self.Domain``, which this
    class does not define -- subclasses are expected to supply it.

    Args:
        a: the inner domain object, or a source string that is ``eval``-ed to
           build it.
    """

    def __init__(self, a):
        if type(a) is str:
            a = eval(a)
        self.a = a

    def box(self, *args, **kargs):
        inner = self.a.box(*args, **kargs)
        return self.Domain(inner)

    def boxBetween(self, *args, **kargs):
        inner = self.a.boxBetween(*args, **kargs)
        return self.Domain(inner)

    def line(self, *args, **kargs):
        inner = self.a.line(*args, **kargs)
        return self.Domain(inner)
|
|
class DList(object):
    """Evaluate several domains together and combine them in an ``ai.ListDomain``.

    Each member is a ``(domain, weight)`` pair. Every member's result is wrapped
    in an ``ai.TaggedDomain`` whose tag is an ``MLoss`` carrying the member's
    weight divided by the sum of all weights (see ``getDiv``).

    Args:
        *al: ``(domain, weight)`` pairs; a string domain is ``eval``-ed and each
             weight goes through ``S.Const.initConst``. Defaults to
             ``Point()`` at 1.0 and ``Box()`` at 0.1.
    """
    Domain = ai.ListDomain

    class MLoss():
        """Loss tag that scales a member domain's loss by its normalized weight."""

        def __init__(self, aw):
            self.aw = aw

        def loss(self, dom, *args, lr = 1, **kargs):
            # Non-positive weights contribute nothing and skip dom.loss entirely.
            if self.aw <= 0.0:
                return 0
            return self.aw * dom.loss(*args, lr = lr * self.aw, **kargs)

    def __init__(self, *al):
        if len(al) == 0:
            al = [("Point()", 1.0), ("Box()", 0.1)]

        self.al = [(eval(a) if type(a) is str else a, S.Const.initConst(aw)) for a,aw in al]

    def getDiv(self, **kargs):
        """Return 1 / (sum of all member weights evaluated with ``kargs``)."""
        return 1.0 / sum(aw.getVal(**kargs) for _,aw in self.al)

    def _tagged(self, method_name, args, kargs):
        """Call ``method_name`` on every member and wrap the tagged results in ``Domain``."""
        # args/kargs are taken as plain parameters (not *args/**kargs) so that no
        # caller keyword can collide with ``method_name``.
        m = self.getDiv(**kargs)
        return self.Domain(
            ai.TaggedDomain(getattr(a, method_name)(*args, **kargs),
                            DList.MLoss(aw.getVal(**kargs) * m))
            for a,aw in self.al)

    def box(self, *args, **kargs):
        return self._tagged("box", args, kargs)

    def boxBetween(self, *args, **kargs):
        return self._tagged("boxBetween", args, kargs)

    def line(self, *args, **kargs):
        return self._tagged("line", args, kargs)

    def __str__(self):
        return "DList(%s)" % h.sumStr("("+str(a)+","+str(w)+")" for a,w in self.al)
|
|
class Mix(DList):
    """DList over exactly two domains, ``a`` weighted by ``aw`` and ``b`` by ``bw``."""

    def __init__(self, a="Point()", b="Box()", aw = 1.0, bw = 0.1):
        first, second = (a, aw), (b, bw)
        super(Mix, self).__init__(first, second)
|
|
class LinMix(DList):
    """Two-domain DList: ``a`` weighted by ``S.Complement(bw)``, ``b`` by ``bw``.

    NOTE(review): presumably ``S.Complement`` makes the two weights sum to one;
    confirm against the scheduling module.
    """

    def __init__(self, a="Point()", b="Box()", bw = 0.1):
        weighted = [(a, S.Complement(bw)), (b, bw)]
        super(LinMix, self).__init__(*weighted)
|
|
class DProb(object):
    """Pick one member domain at random per call and delegate to it.

    Args:
        *doms: ``(domain, weight)`` pairs; string domains are ``eval``-ed and the
               weights are normalized to probabilities. Defaults to ``Point()``
               at 0.8 and ``Box()`` at 0.2.
    """

    def __init__(self, *doms):
        if len(doms) == 0:
            doms = [("Point()", 0.8), ("Box()", 0.2)]
        div = 1.0 / sum(float(aw) for _,aw in doms)
        self.domains = [eval(a) if type(a) is str else a for a,_ in doms]
        self.probs = [ div * float(aw) for _,aw in doms]

    def chooseDom(self):
        """Sample a member with ``np.random.choice``; a single member is returned without sampling."""
        return self.domains[np.random.choice(len(self.domains), p = self.probs)] if len(self.domains) > 1 else self.domains[0]

    def box(self, *args, **kargs):
        domain = self.chooseDom()
        return domain.box(*args, **kargs)

    def boxBetween(self, *args, **kargs):
        # Previously missing: DProb/Coin could not be used where boxBetween is
        # required, unlike DList/WrapDom/Point, which all provide it.
        domain = self.chooseDom()
        return domain.boxBetween(*args, **kargs)

    def line(self, *args, **kargs):
        domain = self.chooseDom()
        return domain.line(*args, **kargs)

    def __str__(self):
        return "DProb(%s)" % h.sumStr("("+str(a)+","+str(w)+")" for a,w in zip(self.domains, self.probs))
|
|
class Coin(DProb):
    """Two-way DProb choosing ``a`` with weight ``ap`` and ``b`` with weight ``bp``."""

    def __init__(self, a="Point()", b="Box()", ap = 0.8, bp = 0.2):
        choices = [(a, ap), (b, bp)]
        super(Coin, self).__init__(*choices)
|
|
class Point(object):
    """Degenerate domain that works on concrete tensors instead of regions.

    ``box`` returns its input unchanged; ``line`` and ``boxBetween`` return the
    midpoint of their two endpoints. ``Domain`` is ``h.dten`` (presumably a
    tensor constructor from helpers -- confirm there).
    """
    Domain = h.dten

    def __init__(self, **kargs):
        pass

    @staticmethod
    def _midpoint(lo, hi):
        """Elementwise average of two endpoints."""
        return (lo + hi) / 2

    def box(self, original, *args, **kargs):
        return original

    def line(self, original, other, *args, **kargs):
        return Point._midpoint(original, other)

    def boxBetween(self, o1, o2, *args, **kargs):
        return Point._midpoint(o1, o2)

    def __str__(self):
        return "Point()"
|
|
class PointA(Point):
    """Point variant whose ``boxBetween`` keeps the first endpoint rather than the midpoint."""

    def boxBetween(self, o1, o2, *args, **kargs):
        chosen = o1
        return chosen

    def __str__(self):
        return "PointA()"
|
|
class PointB(Point):
    """Point variant whose ``boxBetween`` keeps the second endpoint rather than the midpoint."""

    def boxBetween(self, o1, o2, *args, **kargs):
        chosen = o2
        return chosen

    def __str__(self):
        return "PointB()"
|
|
|
|
class NormalPoint(Point):
    """Point domain that perturbs the input with zero-mean Gaussian noise.

    Args:
        w: fixed noise scale; when not None it overrides the ``w`` passed to ``box``.
    """

    def __init__(self, w = None, **kargs):
        self.epsilon = w

    def box(self, original, w, *args, **kargs):
        """Return ``original`` plus Gaussian noise.

        ``original`` is the mean. ``w`` (or ``self.epsilon`` when set) is the
        standard deviation of the noise: unit-variance ``randn`` samples are
        multiplied by ``w``, so the variance is ``w ** 2``.
        """
        if self.epsilon is not None:
            w = self.epsilon

        inter = torch.randn_like(original, device = h.device) * w
        return original + inter

    def __str__(self):
        return "NormalPoint(%s)" % ("" if self.epsilon is None else str(self.epsilon))
|
|
|
|
|
|
class MI_FGSM(Point):
    """Momentum iterative FGSM attack that returns a concrete adversarial tensor.

    Each of the ``k`` steps back-propagates ``loss_function(model(x), target)``,
    adds the L1-normalised input gradient to a momentum buffer
    (``gradorg = mu * gradorg + grad / ||grad||_1``), moves ``x`` by
    ``(r * w / k) * ai.mysign(gradorg)`` (ascending the loss when untargeted,
    descending otherwise), and clamps ``x - xo`` elementwise into ``[-w, w]``.

    Args:
        w: attack radius, wrapped with ``S.Const.initConst``; its ``getVal`` is
           given the caller's ``w`` as ``c`` (presumably a fallback -- confirm
           against scheduling).
        r: step-size multiplier; each step has magnitude ``r * w / k``.
        k: number of iterations.
        mu: momentum decay factor.
        should_end: when True, steps are multiplied by
           ``ai.mulIfEq(grad, out, target)`` (presumably masking examples that
           are already misclassified -- confirm in ai).
        restart: if truthy, re-initialise ``x`` roughly ``restart`` times over
           the ``k`` iterations.
        searchable: stored on the instance; not read by this class.
    """

    def __init__(self, w = None, r = 20.0, k = 100, mu = 0.8, should_end = True, restart = None, searchable=False,**kargs):
        self.epsilon = S.Const.initConst(w)
        self.k = k
        self.mu = mu
        self.r = float(r)
        self.should_end = should_end
        self.restart = restart
        self.searchable = searchable

    def box(self, original, model, target = None, untargeted = False, **kargs):
        """Attack around ``original``; with no ``target``, attack the model's own prediction untargeted."""
        if target is None:
            untargeted = True
            with torch.no_grad():
                target = model(original).max(1)[1]
        return self.attack(model, original, untargeted, target, **kargs)

    def attack(self, model, xo, untargeted, target, w, loss_function=ai.stdLoss, **kargs):
        """Run the iterative attack starting at ``xo`` and return the perturbed input."""
        w = self.epsilon.getVal(c = w, **kargs)

        x = nn.Parameter(xo.clone(), requires_grad=True)
        gradorg = h.zeros(x.shape)  # momentum buffer
        is_eq = 1  # step mask; stays 1 unless should_end updates it

        w = h.ones(x.shape) * w  # broadcast the radius to the input's shape

        # Restart period, computed once. int(k / restart) is 0 when restart > k,
        # which would make the modulo below divide by zero, so floor it at 1;
        # restart == 0 is treated as "no restarts".
        restart_every = max(1, int(self.k / self.restart)) if self.restart else None

        for i in range(self.k):
            if restart_every is not None and i % restart_every == 0:
                # NOTE(review): torch.rand_like is uniform on [0, 1), so restarts only
                # offset in the positive direction; confirm this is intended.
                x = is_eq * (torch.rand_like(xo) * w + xo) + (1 - is_eq) * x
                x = nn.Parameter(x, requires_grad = True)

            model.optimizer.zero_grad()

            out = model(x).vanillaTensorPart()
            loss = loss_function(out, target)

            loss.sum().backward(retain_graph=True)
            with torch.no_grad():
                oth = x.grad / torch.norm(x.grad, p=1)
                gradorg *= self.mu
                gradorg += oth
                grad = (self.r * w / self.k) * ai.mysign(gradorg)
                if self.should_end:
                    is_eq = ai.mulIfEq(grad, out, target)
                x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)

                # Clamp the perturbation elementwise into [-w, w] around xo.
                x = xo + torch.min(torch.max(x - xo, -w),w)
                x.requires_grad_()

        model.optimizer.zero_grad()

        return x

    def boxBetween(self, o1, o2, model, target = None, *args, **kargs):
        # The previous code defined boxBetween twice: the first had its arguments
        # scrambled and was shadowed anyway, and the second raised a plain string,
        # which is itself a TypeError in Python 3.
        raise NotImplementedError("boxBetween is not yet supported by MI_FGSM")

    def __str__(self):
        # Omitted fields must equal the __init__ defaults so that eval(str(...))
        # reconstructs an equivalent attack.
        return "MI_FGSM(%s)" % (("" if self.epsilon is None else "w="+str(self.epsilon)+",")
                                + ("" if self.k == 100 else "k="+str(self.k)+",")
                                + ("" if self.r == 20.0 else "r="+str(self.r)+",")
                                + ("" if self.mu == 0.8 else "mu="+str(self.mu)+",")
                                + ("" if self.restart is None else "restart="+str(self.restart)+",")
                                + ("" if self.should_end else "should_end=False"))
|
|
|
|
class PGD(MI_FGSM):
    """Projected gradient attack: MI_FGSM with momentum disabled (``mu = 0``).

    Defaults to ``k = 5`` iterations and ``r = 5.0``.
    """

    def __init__(self, r = 5.0, k = 5, **kargs):
        super(PGD, self).__init__(r=r, k=k, mu=0, **kargs)

    def __str__(self):
        pieces = []
        pieces.append("" if self.epsilon is None else "w=" + str(self.epsilon) + ",")
        pieces.append("" if self.k == 5 else "k=" + str(self.k) + ",")
        pieces.append("" if self.r == 5.0 else "r=" + str(self.r) + ",")
        pieces.append("" if self.should_end else "should_end=False")
        return "PGD(%s)" % "".join(pieces)
|
|
class IFGSM(PGD):
    """Iterative FGSM: PGD with the step multiplier ``r`` fixed to 1."""

    def __init__(self, k = 5, **kargs):
        super(IFGSM, self).__init__(r = 1, k=k, **kargs)

    def __str__(self):
        eps_part = "" if self.epsilon is None else "w=" + str(self.epsilon) + ","
        k_part = "" if self.k == 5 else "k=" + str(self.k) + ","
        end_part = "" if self.should_end else "should_end=False"
        return "IFGSM(%s)" % (eps_part + k_part + end_part)
|
|
class NormalAdv(Point):
    """Run adversarial domain ``a`` with a randomly scaled radius.

    ``box`` draws a batch-sized ``torch.randn`` vector, keeps only its first
    entry, and multiplies that scalar into the radius from ``self.epsilon``
    before delegating to ``a.box``.

    NOTE(review): the standard-normal scalar can be negative, which makes the
    radius passed to ``a`` negative; confirm this is intended.
    """

    def __init__(self, a="IFGSM()", w = None):
        self.a = eval(a) if type(a) is str else a
        self.epsilon = S.Const.initConst(w)

    def box(self, original, w, *args, **kargs):
        scale = self.epsilon.getVal(c = w, shape = original.shape[:1], **kargs)
        assert (0 <= h.dten(scale)).all()
        draw = torch.randn(original.size()[0:1], device = h.device)[0]
        return self.a.box(original, w = draw * scale, *args, **kargs)

    def __str__(self):
        suffix = "" if self.epsilon is None else ",w=" + str(self.epsilon)
        return "NormalAdv(%s)" % (str(self.a) + suffix)
|
|
|
|
class InclusionSample(Point):
    """Sample a sub-region of a radius-``w`` neighbourhood and wrap it in domain ``a``.

    ``box`` adds a random offset scaled by ``w * (1 - sub)`` to the input, then
    calls ``a.box`` around the shifted point with radius ``w * sub``. The offset
    is drawn per element, uniform on [-1, 1) by default or standard normal when
    ``normal`` is set. In the uniform case the result therefore stays within
    radius ``w`` of the input.

    Args:
        sub: fraction of the radius handed to ``a`` (asserted to be in [0, 1]).
        a: inner domain (string is ``eval``-ed), ``Box()`` by default.
        normal: use Gaussian instead of uniform offsets.
        w: radius, wrapped with ``S.Const.initConst``.
    """

    def __init__(self, sub, a="Box()", normal = False, w = None, **kargs):
        self.sub = S.Const.initConst(sub)
        self.w = S.Const.initConst(w)
        self.normal = normal
        self.a = (eval(a) if type(a) is str else a)

    def box(self, original, w, *args, **kargs):
        w = self.w.getVal(c = w, shape = original.shape[:1], **kargs)
        sub = self.sub.getVal(c = 1, shape = original.shape[:1], **kargs)

        assert (0 <= h.dten(w)).all()
        assert (h.dten(sub) <= 1).all()
        assert (0 <= h.dten(sub)).all()
        if self.normal:
            inter = torch.randn_like(original, device = h.device)
        else:
            inter = (torch.rand_like(original, device = h.device) * 2 - 1)

        inter = inter * w * (1 - sub)

        return self.a.box(original + inter, w = w * sub, *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        # box treats w as a radius around the centre, so the region between o1 and
        # o2 has radius |o2 - o1| / 2 (the full width would double its extent).
        w = (o2 - o1).abs() / 2
        return self.box( (o2 + o1)/2 , w = w, *args, **kargs)

    def __str__(self):
        # The radius is stored as self.w; there is no self.epsilon on this class.
        return "InclusionSample(%s, %s)" % (str(self.sub), str(self.a) + ("" if self.w is None else ",w="+str(self.w)))


InSamp = InclusionSample
|
|
|
|
class AdvInclusion(InclusionSample):
    """Run adversarial domain ``a`` on part of the radius, then wrap the result in ``b``.

    ``a.box`` is called with radius ``w * (1 - sub)``; it is skipped, and the
    input used as-is, when the summed radius is <= 0. ``b.box`` is then called
    on that result with radius ``w * sub``.

    Args:
        sub: fraction of the radius handed to ``b`` (asserted to be in [0, 1]).
        a: adversarial domain (string is ``eval``-ed), ``IFGSM()`` by default.
        b: enclosing domain (string is ``eval``-ed), ``Box()`` by default.
        w: radius, wrapped with ``S.Const.initConst``.
    """

    def __init__(self, sub, a="IFGSM()", b="Box()", w = None, **kargs):
        self.sub = S.Const.initConst(sub)
        self.w = S.Const.initConst(w)
        self.a = (eval(a) if type(a) is str else a)
        self.b = (eval(b) if type(b) is str else b)

    def box(self, original, w, *args, **kargs):
        w = self.w.getVal(c = w, shape = original.shape, **kargs)
        sub = self.sub.getVal(c = 1, shape = original.shape, **kargs)

        assert (0 <= h.dten(w)).all()
        assert (h.dten(sub) <= 1).all()
        assert (0 <= h.dten(sub)).all()

        # A zero radius leaves nothing to attack, so skip the adversarial step.
        if h.dten(w).sum().item() <= 0.0:
            inter = original
        else:
            inter = self.a.box(original, w = w * (1 - sub), *args, **kargs)

        return self.b.box(inter, w = w * sub, *args, **kargs)

    def __str__(self):
        # The radius is stored as self.w; there is no self.epsilon on this class.
        return "AdvInclusion(%s, %s, %s)" % (str(self.sub), str(self.a), str(self.b) + ("" if self.w is None else ",w="+str(self.w)))
|
|
|
|
class AdvDom(Point):
    """Domain spanning from the input toward an adversarial example.

    ``a`` produces an adversarial result (via ``box``/``boxBetween``), and ``b``
    then builds a region with ``boxBetween`` between the clean centre and that
    result's ``ub()``.
    """

    def __init__(self, a="IFGSM()", b="Box()"):
        self.a = (eval(a) if type(a) is str else a)
        self.b = (eval(b) if type(b) is str else b)

    def box(self, original,*args, **kargs):
        adv = self.a.box(original, *args, **kargs)
        return self.b.boxBetween(original, adv.ub(), *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        original = (o1 + o2) / 2
        adv = self.a.boxBetween(o1, o2, *args, **kargs)
        return self.b.boxBetween(original, adv.ub(), *args, **kargs)

    def __str__(self):
        # __init__ never sets ``width``; reading it unconditionally raised
        # AttributeError. Keep the optional prefix for instances that do set it.
        width = getattr(self, "width", None)
        return "AdvDom(%s)" % (("" if width is None else "width="+str(width)+",")
                               + str(self.a) + "," + str(self.b))
|
|
|
|
|
|
class BiAdv(AdvDom):
    """Symmetric AdvDom: a region of radius |adv - centre| on both sides of the centre."""

    def box(self, original, *args, **kargs):
        # Accept and forward positional extras like AdvDom.box and boxBetween do;
        # previously any extra positional argument raised TypeError here.
        adv = self.a.box(original, *args, **kargs)
        extreme = (adv.ub() - original).abs()
        return self.b.boxBetween(original - extreme, original + extreme, *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        original = (o1 + o2) / 2
        adv = self.a.boxBetween(o1, o2, *args, **kargs)
        extreme = (adv.ub() - original).abs()
        return self.b.boxBetween(original - extreme, original + extreme, *args, **kargs)

    def __str__(self):
        # Reuse AdvDom's formatting, swapping the 6-character "AdvDom" prefix.
        return "BiAdv" + AdvDom.__str__(self)[6:]
|
|
|
|
class HBox(object):
    """Input domain built as an ``ai.HybridZonotope`` with one error term per input element.

    Results are wrapped in ``ai.TaggedDomain`` with this object as the tag
    (see ``domain``); ``loss`` computes the training loss for such a domain.

    Args:
        w: radius, wrapped with ``S.Const.initConst``.
        tot_weight: weight of the classification loss term.
        width_weight: weight of the ``dom.diameter()`` term (only added when > 0).
        pow_loss: if set, > 0 and != 1, the loss is raised to this power.
        log_loss: if True, the loss is replaced by ``log(loss + 1)``.
        searchable: stored on the instance; not read by this class.
        cross_loss: if True, compute the loss on worst-case logits (see ``loss``).
    """
    Domain = ai.HybridZonotope

    def domain(self, *args, **kargs):
        return ai.TaggedDomain(self.Domain(*args, **kargs), self)

    def __init__(self, w = None, tot_weight = 1, width_weight = 0, pow_loss = None, log_loss = False, searchable = True, cross_loss = True, **kargs):
        self.w = S.Const.initConst(w)
        self.tot_weight = S.Const.initConst(tot_weight)
        self.width_weight = S.Const.initConst(width_weight)
        self.pow_loss = pow_loss
        self.searchable = searchable
        self.log_loss = log_loss
        self.cross_loss = cross_loss

    def __str__(self):
        return "HBox(%s)" % ("" if self.w is None else "w="+str(self.w))

    def _errorBasis(self, x):
        """Return the per-element unit error terms from ``h.getEi``, reshaped to (num_elem, *x.size()) for inputs of rank > 2."""
        batches = x.size()[0]
        num_elem = h.product(x.size()[1:])
        ei = h.getEi(batches, num_elem)
        if len(x.size()) > 2:
            ei = ei.contiguous().view(num_elem, *x.size())
        return ei

    def boxBetween(self, o1, o2, *args, **kargs):
        ei = self._errorBasis(o1)
        return self.domain((o1 + o2) / 2, None, ei * (o2 - o1).abs() / 2).checkSizes()

    def box(self, original, w, **kargs):
        """
        This version of it is slow, but keeps correlation down the line.
        """
        radius = self.w.getVal(c = w, **kargs)
        ei = self._errorBasis(original)
        return self.domain(original, None, ei * radius).checkSizes()

    def line(self, o1, o2, **kargs):
        """Segment between o1 and o2, optionally widened by per-element error terms of size w."""
        w = self.w.getVal(c = 0, **kargs)

        ln = ((o2 - o1) / 2).unsqueeze(0)
        if w is not None and w > 0.0:
            ln = torch.cat([ln, self._errorBasis(o1) * w])
        return self.domain((o1 + o2) / 2, None, ln ).checkSizes()

    def loss(self, dom, target, *args, **kargs):
        """Weighted average of the classification loss and (optionally) the domain diameter."""
        width_weight = self.width_weight.getVal(**kargs)
        tot_weight = self.tot_weight.getVal(**kargs)

        if self.cross_loss:
            # Worst-case logits: upper bounds everywhere except the target class,
            # which takes its lower bound.
            r = dom.ub()
            inds = torch.arange(r.shape[0], device=h.device, dtype=h.ltype)
            r[inds,target] = dom.lb()[inds,target]
            tot = r.loss(target, *args, **kargs)
        else:
            tot = dom.loss(target, *args, **kargs)

        if self.log_loss:
            tot = (tot + 1).log()
        if self.pow_loss is not None and self.pow_loss > 0 and self.pow_loss != 1:
            tot = tot.pow(self.pow_loss)

        ls = tot * tot_weight
        if width_weight > 0:
            ls += dom.diameter() * width_weight

        return ls / (width_weight + tot_weight)
|
|
class Box(HBox):
    """Interval domain: an HBox whose regions carry only a per-element ``beta`` radius and no error terms."""

    def __str__(self):
        if self.w is None:
            return "Box(%s)" % ""
        return "Box(%s)" % ("w=" + str(self.w))

    def box(self, original, w, **kargs):
        """Box of the given radius around ``original`` using uncorrelated betas.

        The betas stay uncorrelated through the rest of the network. Tests have
        nevertheless shown better accuracy than HBox.box: that version creates
        many zero error terms, which the relu's newhead calculation accounts
        for, and that turns out worse than not accounting for them.
        """
        radius = self.w.getVal(c = w, **kargs)
        beta = h.ones(original.size()) * radius
        return self.domain(original, beta, None).checkSizes()

    def line(self, o1, o2, **kargs):
        pad = self.w.getVal(c = 0, **kargs)
        centre = (o1 + o2) / 2
        half_span = ((o2 - o1) / 2).abs()
        beta = half_span + h.ones(o2.size()) * pad
        return self.domain(centre, beta, None).checkSizes()

    def boxBetween(self, o1, o2, *args, **kargs):
        return self.line(o1, o2, **kargs)
|
|
class ZBox(HBox):
    """HBox variant whose ``Domain`` constructs an ``ai.Zonotope``."""

    def __str__(self):
        label = "" if self.w is None else "w=" + str(self.w)
        return "ZBox(%s)" % label

    def Domain(self, *args, **kargs):
        return ai.Zonotope(*args, **kargs)
|
|
class HSwitch(HBox):
    """HBox variant using ``ai.creluSwitch`` as the hybrid zonotope's custom ReLU."""

    def __str__(self):
        if self.w is None:
            return "HSwitch(%s)" % ""
        return "HSwitch(%s)" % ("w=" + str(self.w))

    def Domain(self, *args, **kargs):
        relu = ai.creluSwitch
        return ai.HybridZonotope(*args, customRelu = relu, **kargs)
| |
class ZSwitch(ZBox):
    """ZBox variant using ``ai.creluSwitch`` as the zonotope's custom ReLU."""

    def __str__(self):
        label = "" if self.w is None else "w=" + str(self.w)
        return "ZSwitch(%s)" % label

    def Domain(self, *args, **kargs):
        relu = ai.creluSwitch
        return ai.Zonotope(*args, customRelu = relu, **kargs)
|
|
|
|
class ZNIPS(ZBox):
    """ZBox variant using ``ai.creluNIPS`` as the zonotope's custom ReLU."""

    def __str__(self):
        # Previously reported "ZSwitch" (copy-paste), mislabelling this domain.
        return "ZNIPS(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.Zonotope(*args, customRelu = ai.creluNIPS, **kargs)
| |
class HSmooth(HBox):
    """HBox variant using ``ai.creluSmooth`` as the hybrid zonotope's custom ReLU."""

    def __str__(self):
        if self.w is None:
            return "HSmooth(%s)" % ""
        return "HSmooth(%s)" % ("w=" + str(self.w))

    def Domain(self, *args, **kargs):
        relu = ai.creluSmooth
        return ai.HybridZonotope(*args, customRelu = relu, **kargs)
| |
class HNIPS(HBox):
    """HBox variant using ``ai.creluNIPS`` as the hybrid zonotope's custom ReLU."""

    def __str__(self):
        # Previously reported "HSmooth" (copy-paste), mislabelling this domain.
        return "HNIPS(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.HybridZonotope(*args, customRelu = ai.creluNIPS, **kargs)
|
|
class ZSmooth(ZBox):
    """ZBox variant using ``ai.creluSmooth`` as the zonotope's custom ReLU."""

    def __str__(self):
        label = "" if self.w is None else "w=" + str(self.w)
        return "ZSmooth(%s)" % label

    def Domain(self, *args, **kargs):
        relu = ai.creluSmooth
        return ai.Zonotope(*args, customRelu = relu, **kargs)
|
|
|
|
|
|
|
|
|
|
| |
class HRand(WrapDom):
    """Build a ``Box()`` region, then re-express it in domain ``a`` via stochastic correlation.

    The inner WrapDom domain is always ``Box()``. Its results are passed through
    ``abstractApplyLeaf('stochasticCorrelate', num_correlated)`` under
    ``torch.no_grad()``. The head/beta/errors of the result are then handed to
    ``self.dom.Domain``.

    Args:
        num_correlated: argument forwarded to ``stochasticCorrelate``.
        a: target domain (string is ``eval``-ed), ``HSwitch()`` by default;
           stored as ``self.dom``.
    """

    def __init__(self, num_correlated, a = "HSwitch()", **kargs):
        super(HRand, self).__init__(Box())
        self.num_correlated = num_correlated
        self.dom = eval(a) if type(a) is str else a

    def Domain(self, d):
        with torch.no_grad():
            out = d.abstractApplyLeaf('stochasticCorrelate', self.num_correlated)
            out = self.dom.Domain(out.head, out.beta, out.errors)
        return out

    def __str__(self):
        # self.a is always the internal Box(); the configured domain is self.dom.
        return "HRand(%s, domain = %s)" % (str(self.num_correlated), str(self.dom))
|
|