| import future |
| import builtins |
| import past |
| import six |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import torch.autograd |
|
|
| from functools import reduce |
|
|
| try: |
| from . import helpers as h |
| except: |
| import helpers as h |
|
|
|
|
|
|
def catNonNullErrors(op, ref_errs=None):
    """Build a binary op over error-term tensors that may disagree in the
    number of terms (size along dim 0).

    The first `sz` terms are combined pairwise; each operand's remaining
    terms are combined against zeros and concatenated, so no term is lost.
    `sz` comes from ref_errs when given, else the smaller operand.
    """
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        # Same number of terms: combine directly.
        if sS == sL:
            return op(erS, erL)

        if ref_errs is not None:
            sz = ref_errs.size()[0]
        else:
            sz = min(sS, sL)

        p1 = op(erS[:sz], erL[:sz])
        erSrem = erS[sz:]
        # Bug fix: the second remainder must come from erL, not erS
        # (previously duplicated erS's tail and dropped erL's entirely).
        erLrem = erL[sz:]
        p2 = op(erSrem, h.zeros(erSrem.shape))
        p3 = op(h.zeros(erLrem.shape), erLrem)
        return torch.cat((p1, p2, p3), dim=0)
    return doop
|
|
def creluBoxy(dom):
    """Abstract ReLU transformer ("boxy" variant) for hybrid zonotopes.

    Units whose concrete interval straddles zero are over-approximated by
    the box [0, ub] (center ub/2, radius ub/2, error terms dropped);
    units with positive center pass through; the rest are zeroed.
    """
    if dom.errors is None:
        if dom.beta is None:
            # Point domain: ReLU is exact.
            return dom.new(F.relu(dom.head), None, None)
        # Pure box: push both interval endpoints through ReLU.
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    aber = torch.abs(dom.errors)

    # Total radius per unit: sum of |error terms| (plus beta if present).
    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mx = dom.head + sm
    mn = dom.head - sm

    # Crossing units (lb < 0 < ub) must be boxed.
    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0).to_dtype()
    mx /= 2  # ub/2: new center AND new radius for boxed units
    newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
    newbeta = h.ifThenElse(should_box, mx, gtz * (dom.beta if not dom.beta is None else 0))
    # Error terms survive only for non-boxed units with positive center.
    newerr = (1 - should_box.to_dtype()) * gtz * dom.errors

    return dom.new(newhead, newbeta , newerr)
|
|
|
|
def creluBoxySound(dom):
    """Like creluBoxy, but pads the resulting beta with a small constant
    (2e-6) so the approximation stays sound under float rounding.
    """
    if dom.errors is None:
        if dom.beta is None:
            # Point domain: ReLU is exact.
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        # Extra 2e-6 slack on the radius for floating-point soundness.
        return dom.new((mn + mx) / 2, (mx - mn) / 2 + 2e-6 , None)

    aber = torch.abs(dom.errors)

    # Total radius per unit: sum of |error terms| (plus beta if present).
    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mx = dom.head + sm
    mn = dom.head - sm

    # Crossing units (lb < 0 < ub) must be boxed.
    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0).to_dtype()
    mx /= 2  # ub/2: new center AND new radius for boxed units
    newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
    newbeta = h.ifThenElse(should_box, mx + 2e-6, gtz * (dom.beta if not dom.beta is None else 0))
    newerr = (1 - should_box.to_dtype()) * gtz * dom.errors

    return dom.new(newhead, newbeta, newerr)
|
|
|
|
def creluSwitch(dom):
    """Abstract ReLU that, for crossing units, switches between two
    approximations: box to [0, ub] when the interval is mostly negative
    (|lb| > ub), otherwise shift the center up by |lb|/2 and widen beta,
    keeping the error terms.
    """
    if dom.errors is None:
        if dom.beta is None:
            # Point domain: ReLU is exact.
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = sm
    mx += dom.head  # mx = head + sm (upper bound), built in place

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0)

    mn.neg_()  # mn now holds -lb (= |lb| for crossing units)
    should_boxer = mn.gt(mx)  # |lb| > ub: mostly-negative, use the box

    mn /= 2  # mn now |lb|/2, the shift amount for non-boxed crossing units
    newhead = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, dom.head + mn ), gtz.to_dtype() * dom.head)
    zbet = dom.beta if not dom.beta is None else 0
    newbeta = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, mn + zbet), gtz.to_dtype() * zbet)
    # Error terms kept for shifted crossing units and positive-center units.
    newerr = h.ifThenElseL(should_box, 1 - should_boxer, gtz).to_dtype() * dom.errors

    return dom.new(newhead, newbeta , newerr)
|
|
def creluSmooth(dom):
    """Abstract ReLU that blends smoothly between the "shift" and "box"
    approximations, weighted by how negative the interval is
    (t = |lb| / (|lb| + ub), with an epsilon for stability).
    """
    if dom.errors is None:
        if dom.beta is None:
            # Point domain: ReLU is exact.
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = sm
    mx += dom.head  # mx = head + sm (upper bound), built in place

    # nmn = max(0, -lb): how far the interval dips below zero.
    nmn = F.relu(-1 * mn)

    # "Shift" candidate: raise the center by |lb|/2, widen beta, keep errors.
    zbet = (dom.beta if not dom.beta is None else 0)
    newheadS = dom.head + nmn / 2
    newbetaS = zbet + nmn / 2
    newerrS = dom.errors

    mmx = F.relu(mx)

    # "Box" candidate: [0, ub] with no error terms.
    newheadB = mmx / 2
    newbetaB = newheadB
    newerrB = 0

    # Blend weight: 0 when the interval is entirely positive, ->1 as it
    # becomes mostly negative.
    eps = 0.0001
    t = nmn / (mmx + nmn + eps)

    # Units with ub <= 0 are definitely inactive: output exactly zero.
    shouldnt_zero = mx.gt(0).to_dtype()

    newhead = shouldnt_zero * ( (1 - t) * newheadS + t * newheadB)
    newbeta = shouldnt_zero * ( (1 - t) * newbetaS + t * newbetaB)
    newerr = shouldnt_zero * ( (1 - t) * newerrS + t * newerrB)

    return dom.new(newhead, newbeta , newerr)
|
|
|
|
def creluNIPS(dom):
    """Abstract ReLU using the linear relaxation y = lam*x + mu with
    lam = ub/(ub-lb) and mu = -lam*lb/2 for crossing units (DeepZ-style;
    presumably after Singh et al. — verify the reference).
    """
    if dom.errors is None:
        if dom.beta is None:
            # Point domain: ReLU is exact.
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 , None)

    # Total radius per unit (error terms plus beta if present).
    sm = torch.sum(torch.abs(dom.errors), 0)

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = dom.head + sm

    # Units with lb >= 0 are definitely active: identity.
    mngz = mn >= 0.0

    zs = h.zeros(dom.head.shape)

    diff = mx - mn

    # Slope of the relaxation; 0 where ub <= 0 (definitely inactive) or
    # the interval is degenerate.
    lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs)
    mu = lam * mn * (-0.5)

    betaz = zs if dom.beta is None else dom.beta

    newhead = torch.where(mngz, dom.head , lam * dom.head + mu)
    # Degenerate intervals (diff <= 0) also keep their components unscaled.
    mngz += diff <= 0.0
    newbeta = torch.where(mngz, betaz , lam * betaz + mu )
    newerr = torch.where(mngz, dom.errors, lam * dom.errors )
    return dom.new(newhead, newbeta, newerr)
|
|
|
|
|
|
|
|
class MaxTypes:
    """Strategies for extracting a concrete "max" tensor from a domain."""

    @staticmethod
    def ub(x):
        """Concrete upper bound of the domain."""
        return x.ub()

    @staticmethod
    def only_beta(x):
        """The box radius alone; zeros shaped like the head when absent."""
        if x.beta is not None:
            return x.beta
        return x.head * 0

    @staticmethod
    def head_beta(x):
        """Box radius plus the center."""
        return MaxTypes.only_beta(x) + x.head
|
|
| class HybridZonotope: |
|
|
    def isSafe(self, target):
        """Per-batch 1/0: 1 where the target class provably dominates all others."""
        od,_ = torch.min(h.preDomRes(self,target).lb(), 1)
        return od.gt(0.0).long()
|
|
    def isPoint(self):
        """An abstract region is never a single concrete point."""
        return False
|
|
    def labels(self):
        """Predicted label plus every label not provably dominated by it."""
        # NOTE(review): indexes lb()[0] — appears to assume batch size 1; confirm.
        target = torch.max(self.ub(), 1)[1]
        l = list(h.preDomRes(self,target).lb()[0])
        return [target.item()] + [ i for i,v in zip(range(len(l)), l) if v <= 0]
|
|
    def relu(self):
        """Apply the configured abstract ReLU transformer."""
        return self.customRelu(self)
| |
    def __init__(self, head, beta, errors, customRelu = creluBoxy, **kargs):
        """head: center tensor; beta: per-unit box radius or None;
        errors: [num_terms, *head.shape] error-term tensor or None;
        customRelu: abstract ReLU transformer (None maps to creluBoxy).
        """
        self.head = head
        self.errors = errors
        self.beta = beta
        # Passing customRelu=None explicitly still selects creluBoxy.
        self.customRelu = creluBoxy if customRelu is None else customRelu
|
|
    def new(self, *args, customRelu = None, **kargs):
        """Build a same-class domain (inheriting customRelu) and validate its sizes."""
        return self.__class__(*args, **kargs, customRelu = self.customRelu if customRelu is None else customRelu).checkSizes()
|
|
    def zono_to_hybrid(self, *args, **kargs):
        """Already a hybrid zonotope: rebuild with the same components."""
        return self.new(self.head, self.beta, self.errors, **kargs)
|
|
| def hybrid_to_zono(self, *args, correlate=True, customRelu = None, **kargs): |
| beta = self.beta |
| errors = self.errors |
| if correlate and beta is not None: |
| batches = beta.shape[0] |
| num_elem = h.product(beta.shape[1:]) |
| ei = h.getEi(batches, num_elem) |
| |
| if len(beta.shape) > 2: |
| ei = ei.contiguous().view(num_elem, *beta.shape) |
| err = ei * beta |
| errors = torch.cat((err, errors), dim=0) if errors is not None else err |
| beta = None |
|
|
| return Zonotope(self.head, beta, errors if errors is not None else (self.beta * 0).unsqueeze(0) , customRelu = self.customRelu if customRelu is None else None) |
|
|
|
|
|
|
    def abstractApplyLeaf(self, foo, *args, **kargs):
        """Dispatch the abstract operation named `foo` on this leaf domain."""
        return getattr(self, foo)(*args, **kargs)
|
|
    def decorrelate(self, cc_indx_batch_err):
        """Keep only the error terms selected per batch by cc_indx_batch_err;
        every other term's magnitude is folded into beta.
        """
        if self.errors is None:
            return self

        batch_size = self.head.shape[0]
        num_error_terms = self.errors.shape[0]

        beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
        errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

        inds_i = torch.arange(self.head.shape[0], device=h.device).unsqueeze(1).long()
        # Move the term axis behind the batch axis: [batch, term, ...].
        errors = errors.to_dtype().permute(1,0, *list(range(len(self.errors.shape)))[2:])

        # Zero out the kept terms, then absorb the rest into beta.
        sm = errors.clone()
        sm[inds_i, cc_indx_batch_err] = 0

        beta = beta.to_dtype() + sm.abs().sum(dim=1)

        # Gather the kept terms and restore the [term, batch, ...] layout.
        errors = errors[inds_i, cc_indx_batch_err]
        errors = errors.permute(1,0, *list(range(len(self.errors.shape)))[2:]).contiguous()
        return self.new(self.head, beta, errors)
| |
| def dummyDecorrelate(self, num_decorrelate): |
| if num_decorrelate == 0 or self.errors is None: |
| return self |
| elif num_decorrelate >= self.errors.shape[0]: |
| beta = self.beta |
| if self.errors is not None: |
| errs = self.errors.abs().sum(dim=0) |
| if beta is None: |
| beta = errs |
| else: |
| beta += errs |
| return self.new(self.head, beta, None) |
| return None |
|
|
    def stochasticDecorrelate(self, num_decorrelate, choices = None, num_to_keep=False):
        """Decorrelate a uniformly random subset of error terms.

        With num_to_keep=True, num_decorrelate is instead the number of
        terms to KEEP. Explicit `choices` bypass the random sampling.
        """
        dummy = self.dummyDecorrelate(num_decorrelate)
        if dummy is not None:
            return dummy
        num_error_terms = self.errors.shape[0]
        batch_size = self.head.shape[0]

        # Uniform sampling without replacement of the indices to keep.
        ucc_mask = h.ones([batch_size, self.errors.shape[0]]).long()
        cc_indx_batch_err = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_decorrelate if num_to_keep else num_error_terms - num_decorrelate, replacement=False)) if choices is None else choices
        return self.decorrelate(cc_indx_batch_err)
|
|
    def decorrelateMin(self, num_decorrelate, num_to_keep=False):
        """Decorrelate the error terms with the smallest total magnitude,
        keeping the largest ones.
        """
        dummy = self.dummyDecorrelate(num_decorrelate)
        if dummy is not None:
            return dummy

        num_error_terms = self.errors.shape[0]
        batch_size = self.head.shape[0]

        # Per-batch total |magnitude| of each term: shape [batch, term].
        error_sum_b_e = self.errors.abs().view(self.errors.shape[0], batch_size, -1).sum(dim=2).permute(1,0)
        cc_indx_batch_err = error_sum_b_e.topk(num_decorrelate if num_to_keep else num_error_terms - num_decorrelate)[1]
        return self.decorrelate(cc_indx_batch_err)
| |
    def correlate(self, cc_indx_batch_beta):
        """Promote the selected (per-batch, flattened) beta entries into
        fresh one-hot error terms; those beta entries become 0.
        """
        num_correlate = h.product(cc_indx_batch_beta.shape[1:])

        beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
        errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

        batch_size = beta.shape[0]
        new_errors = h.zeros([num_correlate] + list(self.head.shape)).to_dtype()

        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

        # Scatter each selected pixel's beta value into its own new term.
        new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(batch_size, num_correlate, -1)
        new_errors[inds_i, nc.unsqueeze(0).expand([batch_size]+list(nc.shape)).squeeze(2), cc_indx_batch_beta] = beta.view(batch_size,-1)[inds_i, cc_indx_batch_beta]

        new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(num_correlate, batch_size, *beta.shape[1:])
        errors = torch.cat((errors, new_errors), dim=0)

        # NOTE(review): writes through a view, mutating beta in place — this
        # aliases self.beta when it was not None; confirm intended.
        beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0

        return self.new(self.head, beta, errors)
|
|
    def stochasticCorrelate(self, num_correlate, choices = None):
        """Correlate a uniformly random subset of pixels (without replacement)."""
        if num_correlate == 0:
            return self

        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])
        # Never ask for more pixels than exist.
        num_correlate = min(num_correlate, num_pixs)
        ucc_mask = h.ones([batch_size, num_pixs ]).long()

        cc_indx_batch_beta = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_correlate, replacement=False)) if choices is None else choices
        return self.correlate(cc_indx_batch_beta)
|
|
|
|
| def correlateMaxK(self, num_correlate): |
| if num_correlate == 0: |
| return self |
| |
| domshape = self.head.shape |
| batch_size = domshape[0] |
| num_pixs = h.product(domshape[1:]) |
| num_correlate = min(num_correlate, num_pixs) |
|
|
| concrete_max_image = self.ub().view(batch_size, -1) |
|
|
| cc_indx_batch_beta = concrete_max_image.topk(num_correlate)[1] |
| return self.correlate(cc_indx_batch_beta) |
|
|
    def correlateMaxPool(self, *args, max_type = MaxTypes.ub , max_pool = F.max_pool2d, **kargs):
        """Correlate the pixels a max-pool would select on the tensor
        extracted by `max_type` (by default, the concrete upper bound).
        """
        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])

        concrete_max_image = max_type(self)

        # return_indices gives the flattened argmax locations per window.
        cc_indx_batch_beta = max_pool(concrete_max_image, *args, return_indices=True, **kargs)[1].view(batch_size, -1)

        return self.correlate(cc_indx_batch_beta)
|
|
| def checkSizes(self): |
| if not self.errors is None: |
| if not self.errors.size()[1:] == self.head.size(): |
| raise Exception("Such bad sizes on error:", self.errors.shape, " head:", self.head.shape) |
| if torch.isnan(self.errors).any(): |
| raise Exception("Such nan in errors") |
| if not self.beta is None: |
| if not self.beta.size() == self.head.size(): |
| raise Exception("Such bad sizes on beta") |
|
|
| if torch.isnan(self.beta).any(): |
| raise Exception("Such nan in errors") |
| if self.beta.lt(0.0).any(): |
| self.beta = self.beta.abs() |
| |
| return self |
|
|
| def __mul__(self, flt): |
| return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt) |
| |
| def __truediv__(self, flt): |
| flt = 1. / flt |
| return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt) |
|
|
    def __add__(self, other):
        """Add another domain component-wise, or shift by a concrete value."""
        if isinstance(other, HybridZonotope):
            return self.new(self.head + other.head, h.msum(self.beta, other.beta, lambda a,b: a + b), h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a + b)))
        else:
            # Concrete tensor/scalar: only the center moves.
            return self.new(self.head + other, self.beta, self.errors)
|
|
    def addPar(self, a, b):
        """Add two domains, aligning their error terms against self's (for parallel branches)."""
        return self.new(a.head + b.head, h.msum(a.beta, b.beta, lambda a,b: a + b), h.msum(a.errors, b.errors, catNonNullErrors(lambda a,b: a + b, self.errors)))
|
|
    def __sub__(self, other):
        """Subtract another domain component-wise, or shift by a concrete value."""
        if isinstance(other, HybridZonotope):
            return self.new(self.head - other.head
                            , h.msum(self.beta, other.beta, lambda a,b: a + b)
                            , h.msum(self.errors, None if other.errors is None else -other.errors, catNonNullErrors(lambda a,b: a + b)))
        else:
            # Concrete tensor/scalar: only the center moves.
            return self.new(self.head - other, self.beta, self.errors)
|
|
    def bmm(self, other):
        """Batch matrix-multiply by a concrete tensor."""
        hd = self.head.bmm(other)
        # Radius transforms through |other| to stay an over-approximation.
        bet = None if self.beta is None else self.beta.bmm(other.abs())

        if self.errors is None:
            er = None
        else:
            # matmul broadcasts over the leading error-term axis; bmm would not.
            er = self.errors.matmul(other)
        return self.new(hd, bet, er)
|
|
|
|
| def getBeta(self): |
| return self.head * 0 if self.beta is None else self.beta |
|
|
| def getErrors(self): |
| return (self.head * 0).unsqueeze(0) if self.beta is None else self.errors |
|
|
    def merge(self, other, ref = None):
        """Join with `other`: keep self where it already contains other,
        keep other where it contains self, else fall back to a bounding
        box around both. `ref` aligns error-term counts in the fallback.
        """
        s_beta = self.getBeta()

        sbox_u = self.head + s_beta
        sbox_l = self.head - s_beta
        o_u = other.ub()
        o_l = other.lb()
        # Is other entirely inside self's box part (ignoring error terms)?
        o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l)

        s_err_mx = self.errors.abs().sum(dim=0)

        if not isinstance(other, HybridZonotope):
            new_head = (self.head + other.center()) / 2
            new_beta = torch.max(sbox_u + s_err_mx,o_u) - new_head
            return self.new(torch.where(o_in_s, self.head, new_head), torch.where(o_in_s, self.beta,new_beta), o_in_s.float() * self.errors)

        s_u = sbox_u + s_err_mx
        s_l = sbox_l - s_err_mx

        # NOTE(review): these look like radii rather than bounds
        # (o_u - head, o_l + head); verify the intended containment test.
        obox_u = o_u - other.head
        obox_l = o_l + other.head

        s_in_o = (s_u <= obox_u) & (s_l >= obox_l)

        # Fallback: box spanning both domains, centered between the heads.
        new_head = (self.head + other.center()) / 2
        new_beta = torch.max(sbox_u + self.getErrors().abs().sum(dim=0),o_u) - new_head

        return self.new(torch.where(o_in_s, self.head, torch.where(s_in_o, other.head, new_head))
                        , torch.where(o_in_s, s_beta,torch.where(s_in_o, other.getBeta(), new_beta))
                        , h.msum(o_in_s.float() * self.errors, s_in_o.float() * other.errors, catNonNullErrors(lambda a,b: a + b, ref_errs = ref.errors if ref is not None else ref)))
| |
|
|
| def conv(self, conv, weight, bias = None, **kargs): |
| h = self.errors |
| inter = h if h is None else h.view(-1, *h.size()[2:]) |
| hd = conv(self.head, weight, bias=bias, **kargs) |
| res = h if h is None else conv(inter, weight, bias=None, **kargs) |
|
|
| return self.new( hd |
| , None if self.beta is None else conv(self.beta, weight.abs(), bias = None, **kargs) |
| , h if h is None else res.view(h.size()[0], h.size()[1], *res.size()[1:])) |
| |
|
|
    def conv1d(self, *args, **kargs):
        """1-D convolution, delegated to the generic conv transfer."""
        return self.conv(lambda x, *args, **kargs: x.conv1d(*args,**kargs), *args, **kargs)
| |
    def conv2d(self, *args, **kargs):
        """2-D convolution, delegated to the generic conv transfer."""
        return self.conv(lambda x, *args, **kargs: x.conv2d(*args,**kargs), *args, **kargs)
|
|
    def conv3d(self, *args, **kargs):
        """3-D convolution, delegated to the generic conv transfer."""
        return self.conv(lambda x, *args, **kargs: x.conv3d(*args,**kargs), *args, **kargs)
|
|
    def conv_transpose1d(self, *args, **kargs):
        """1-D transposed convolution, delegated to the generic conv transfer."""
        return self.conv(lambda x, *args, **kargs: x.conv_transpose1d(*args,**kargs), *args, **kargs)
| |
    def conv_transpose2d(self, *args, **kargs):
        """2-D transposed convolution, delegated to the generic conv transfer."""
        return self.conv(lambda x, *args, **kargs: x.conv_transpose2d(*args,**kargs), *args, **kargs)
|
|
    def conv_transpose3d(self, *args, **kargs):
        """3-D transposed convolution, delegated to the generic conv transfer."""
        return self.conv(lambda x, *args, **kargs: x.conv_transpose3d(*args,**kargs), *args, **kargs)
| |
| def matmul(self, other): |
| return self.new(self.head.matmul(other), None if self.beta is None else self.beta.matmul(other.abs()), None if self.errors is None else self.errors.matmul(other)) |
|
|
| def unsqueeze(self, i): |
| return self.new(self.head.unsqueeze(i), None if self.beta is None else self.beta.unsqueeze(i), None if self.errors is None else self.errors.unsqueeze(i + 1)) |
|
|
    def squeeze(self, dim):
        """Remove dimension `dim`; non-negative dims shift by one on errors
        to skip the leading term axis.
        """
        return self.new(self.head.squeeze(dim),
                        None if self.beta is None else self.beta.squeeze(dim),
                        None if self.errors is None else self.errors.squeeze(dim + 1 if dim >= 0 else dim))
|
|
| def double(self): |
| return self.new(self.head.double(), self.beta.double() if self.beta is not None else None, self.errors.double() if self.errors is not None else None) |
|
|
| def float(self): |
| return self.new(self.head.float(), self.beta.float() if self.beta is not None else None, self.errors.float() if self.errors is not None else None) |
|
|
| def to_dtype(self): |
| return self.new(self.head.to_dtype(), self.beta.to_dtype() if self.beta is not None else None, self.errors.to_dtype() if self.errors is not None else None) |
| |
| def sum(self, dim=1): |
| return self.new(torch.sum(self.head,dim=dim), None if self.beta is None else torch.sum(self.beta,dim=dim), None if self.errors is None else torch.sum(self.errors, dim= dim + 1 if dim >= 0 else dim)) |
|
|
| def view(self,*newshape): |
| return self.new(self.head.view(*newshape), |
| None if self.beta is None else self.beta.view(*newshape), |
| None if self.errors is None else self.errors.view(self.errors.size()[0], *newshape)) |
|
|
    def gather(self,dim, index):
        """Gather along `dim`; the index is broadcast across the error-term axis."""
        return self.new(self.head.gather(dim, index),
                        None if self.beta is None else self.beta.gather(dim, index),
                        None if self.errors is None else self.errors.gather(dim + 1, index.expand([self.errors.size()[0]] + list(index.size()))))
| |
| def concretize(self): |
| if self.errors is None: |
| return self |
|
|
| return self.new(self.head, torch.sum(self.concreteErrors().abs(),0), None) |
| |
    def cat(self,other, dim=0):
        """Concatenate with another domain along `dim` (errors use dim+1)."""
        return self.new(self.head.cat(other.head, dim = dim),
                        h.msum(other.beta, self.beta, lambda a,b: a.cat(b, dim = dim)),
                        h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a.cat(b, dim+1))))
|
|
|
|
| def split(self, split_size, dim = 0): |
| heads = list(self.head.split(split_size, dim)) |
| betas = list(self.beta.split(split_size, dim)) if not self.beta is None else None |
| errorss = list(self.errors.split(split_size, dim + 1)) if not self.errors is None else None |
| |
| def makeFromI(i): |
| return self.new( heads[i], |
| None if betas is None else betas[i], |
| None if errorss is None else errorss[i]) |
| return tuple(makeFromI(i) for i in range(len(heads))) |
|
|
| |
| |
| def concreteErrors(self): |
| if self.beta is None and self.errors is None: |
| raise Exception("shouldn't have both beta and errors be none") |
| if self.errors is None: |
| return self.beta.unsqueeze(0) |
| if self.beta is None: |
| return self.errors |
| return torch.cat([self.beta.unsqueeze(0),self.errors], dim=0) |
|
|
|
|
    def applyMonotone(self, foo, *args, **kargs):
        """Transfer a monotone elementwise function by pushing the interval
        endpoints through it, then re-correlating up to the original number
        of error terms.
        """
        if self.beta is None and self.errors is None:
            # Point domain: apply exactly.
            return self.new(foo(self.head), None , None)

        # Total radius per unit.
        beta = self.concreteErrors().abs().sum(dim=0)

        tp = foo(self.head + beta, *args, **kargs)
        bt = foo(self.head - beta, *args, **kargs)

        new_hybrid = self.new((tp + bt) / 2, (tp - bt) / 2 , None)

        if self.errors is not None:
            # Recover some correlation by promoting the largest boxes back
            # into error terms.
            return new_hybrid.correlateMaxK(self.errors.shape[0])
        return new_hybrid
|
|
| def avg_pool2d(self, *args, **kargs): |
| nhead = F.avg_pool2d(self.head, *args, **kargs) |
| return self.new(nhead, |
| None if self.beta is None else F.avg_pool2d(self.beta, *args, **kargs), |
| None if self.errors is None else F.avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape)) |
|
|
| def adaptive_avg_pool2d(self, *args, **kargs): |
| nhead = F.adaptive_avg_pool2d(self.head, *args, **kargs) |
| return self.new(nhead, |
| None if self.beta is None else F.adaptive_avg_pool2d(self.beta, *args, **kargs), |
| None if self.errors is None else F.adaptive_avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape)) |
|
|
    def elu(self):
        """ELU via the monotone interval transfer."""
        return self.applyMonotone(F.elu)
|
|
    def selu(self):
        """SELU via the monotone interval transfer."""
        return self.applyMonotone(F.selu)
|
|
    def sigm(self):
        """Sigmoid via the monotone interval transfer."""
        return self.applyMonotone(F.sigmoid)
|
|
    def softplus(self):
        """Softplus transfer: first-order expansion around the center with a
        second-order remainder bound folded into beta.
        """
        if self.errors is None:
            if self.beta is None:
                # Point domain: exact.
                return self.new(F.softplus(self.head), None , None)
            # Pure box: softplus is monotone, so use the endpoints.
            tp = F.softplus(self.head + self.beta)
            bt = F.softplus(self.head - self.beta)
            return self.new((tp + bt) / 2, (tp - bt) / 2 , None)

        errors = self.concreteErrors()
        o = h.ones(self.head.size())

        def sp(hd):
            # softplus itself
            return F.softplus(hd)
        def spp(hd):
            # first derivative of softplus: e^x / (e^x + 1) (the sigmoid)
            ehd = torch.exp(hd)
            return ehd.div(ehd + o)
        def sppp(hd):
            # second derivative of softplus: e^x / (e^x + 1)^2
            ehd = torch.exp(hd)
            md = ehd + o
            return ehd.div(md.mul(md))

        fa = sp(self.head)
        fpa = spp(self.head)

        a = self.head

        # Total radius per unit.
        k = torch.sum(errors.abs(), 0)

        def evalG(r):
            # g(r) = r^2 * f''(a + r): second-order remainder candidate.
            return r.mul(r).mul(sppp(a + r))

        # Bound |g| over candidate extrema r in {0, +-k} (and +-a when the
        # interval contains the origin of the expansion).
        m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
        m = h.ifThenElse( a.abs().lt(k), torch.max(m, torch.max(evalG(a), evalG(-a))), m)
        m /= 2

        # Linear part scales the existing terms; remainder bound widens beta.
        return self.new(fa, m if self.beta is None else m + self.beta.mul(fpa), None if self.errors is None else self.errors.mul(fpa))
|
|
    def center(self):
        """The concrete center of the zonotope."""
        return self.head
|
|
    def vanillaTensorPart(self):
        """The underlying plain tensor (the center)."""
        return self.head
|
|
    def lb(self):
        """Concrete elementwise lower bound: center minus total radius."""
        return self.head - self.concreteErrors().abs().sum(dim=0)
|
|
    def ub(self):
        """Concrete elementwise upper bound: center plus total radius."""
        return self.head + self.concreteErrors().abs().sum(dim=0)
|
|
    def size(self):
        """Shape of the underlying center tensor."""
        return self.head.size()
|
|
    def diameter(self):
        """Per-batch total error magnitude, summed over terms and elements."""
        abal = torch.abs(self.concreteErrors()).transpose(0,1)
        return abal.sum(1).sum(1)
|
|
    def loss(self, target, **args):
        """Softplus of the worst-case margin violation against the target class."""
        r = -h.preDomRes(self, target).lb()
        return F.softplus(r.max(1)[0])
|
|
    def deep_loss(self, act = F.relu, *args, **kargs):
        """Regularizer penalizing overlap between units' concrete intervals.

        NOTE(review): relies on a sliding-window max over bounds sorted per
        batch; exact semantics inferred only partially — verify against the
        training code that consumes it.
        """
        batch_size = self.head.shape[0]
        inds = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        def dl(l,u):
            # Sort lower bounds; carry the matching upper bounds along.
            ls, lsi = torch.sort(l, dim=1)
            ls_u = u[inds, lsi]

            def slidingMax(a):
                k = a.shape[1]
                # Shift to non-negative values so the zero padding never wins.
                ml = a.min(dim=1)[0].unsqueeze(1)

                inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)
                mpl = F.max_pool1d(inp.unsqueeze(1) , kernel_size = k, stride=1, padding = 0, return_indices=False).squeeze(1)
                return mpl[:,:-1] + ml

            # Penalize how far previous uppers exceed each sorted lower.
            return act(slidingMax(ls_u) - ls).sum(dim=1)

        l = self.lb().view(batch_size, -1)
        u = self.ub().view(batch_size, -1)
        # Symmetrize over both bound orderings and normalize by width.
        return ( dl(l,u) + dl(-u,-l) ) / (2 * l.shape[1])
|
|
|
|
|
|
class Zonotope(HybridZonotope):
    """Pure zonotope: after every approximating op, the box term beta is
    re-normalized into fresh error terms (so beta stays None).
    """

    def applySuper(self, ret):
        """Fold ret's beta into one-hot error terms via the identity basis."""
        batches = ret.head.size()[0]
        num_elem = h.product(ret.head.size()[1:])
        ei = h.getEi(batches, num_elem)

        if len(ret.head.size()) > 2:
            ei = ei.contiguous().view(num_elem, *ret.head.size())

        # Each beta entry becomes its own error term.
        ret.errors = torch.cat( (ret.errors, ei * ret.beta) ) if not ret.beta is None else ret.errors
        ret.beta = None
        return ret.checkSizes()

    def zono_to_hybrid(self, *args, customRelu = None, **kargs):
        """Reinterpret as a HybridZonotope (beta permitted again)."""
        return HybridZonotope(self.head, self.beta, self.errors, customRelu = self.customRelu if customRelu is None else customRelu)

    def hybrid_to_zono(self, *args, **kargs):
        """Already a zonotope: rebuild with the same components."""
        return self.new(self.head, self.beta, self.errors, **kargs)

    def applyMonotone(self, *args, **kargs):
        """Monotone transfer, then re-normalize beta into error terms."""
        return self.applySuper(super(Zonotope,self).applyMonotone(*args, **kargs))

    def softplus(self):
        """Softplus transfer, then re-normalize beta into error terms."""
        return self.applySuper(super(Zonotope,self).softplus())

    def relu(self):
        """Abstract ReLU, then re-normalize beta into error terms."""
        return self.applySuper(super(Zonotope,self).relu())

    def splitRelu(self, *args, **kargs):
        """Split-ReLU variants, each re-normalized into a pure zonotope."""
        return [self.applySuper(a) for a in super(Zonotope, self).splitRelu(*args, **kargs)]
|
|
|
|
def mysign(x):
    """Sign of x, except zeros map to +1 instead of 0."""
    zero_mask = x.eq(0).to_dtype()
    signs = x.sign().to_dtype()
    return signs + zero_mask
|
|
def mulIfEq(grad, out, target):
    """Mask shaped like grad: 1 where out's argmax equals target, else 0."""
    pred = out.max(1, keepdim=True)[1]
    eq = pred.eq(target.view_as(pred)).to_dtype()
    broadcast_shape = [-1] + [1] * len(grad.size()[1:])
    return eq.view(broadcast_shape).expand_as(grad)
| |
|
|
def stdLoss(out, target):
    """Per-example cross-entropy, compatible with torch 0.x and >= 1.0 APIs."""
    legacy = torch.__version__[0] == "0"
    if legacy:
        return F.cross_entropy(out, target, reduce = False)
    return F.cross_entropy(out, target, reduction='none')
|
|
|
|
|
|
class ListDomain(object):
    """Product domain: a list of abstract domains advanced in lockstep.

    Every abstract operation is applied elementwise to the contained
    domains; concrete arguments are broadcast to all members.
    """

    def __init__(self, al, *args, **kargs):
        self.al = list(al)

    def new(self, *args, **kargs):
        return self.__class__(*args, **kargs)

    def isSafe(self,*args,**kargs):
        # Bug fix: raising a bare string is a TypeError in Python 3;
        # raise a real exception carrying the same message.
        raise Exception("Domain Not Suitable For Testing")

    def labels(self):
        # Bug fix: same string-raise repair as isSafe.
        raise Exception("Domain Not Suitable For Testing")

    def isPoint(self):
        return all(a.isPoint() for a in self.al)

    def __mul__(self, flt):
        return self.new(a.__mul__(flt) for a in self.al)

    def __truediv__(self, flt):
        return self.new(a.__truediv__(flt) for a in self.al)

    def __add__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__add__(o) for a,o in zip(self.al, other.al))
        else:
            return self.new(a.__add__(other) for a in self.al)

    def merge(self, other, ref = None):
        if ref is None:
            return self.new(a.merge(o) for a,o in zip(self.al,other.al))
        return self.new(a.merge(o, ref = r) for a,o,r in zip(self.al,other.al, ref.al))

    def addPar(self, a, b):
        return self.new(s.addPar(av,bv) for s,av,bv in zip(self.al, a.al, b.al))

    def __sub__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__sub__(o) for a,o in zip(self.al, other.al))
        else:
            return self.new(a.__sub__(other) for a in self.al)

    def abstractApplyLeaf(self, *args, **kargs):
        return self.new(a.abstractApplyLeaf(*args, **kargs) for a in self.al)

    def bmm(self, other):
        return self.new(a.bmm(other) for a in self.al)

    def matmul(self, other):
        return self.new(a.matmul(other) for a in self.al)

    def conv(self, *args, **kargs):
        return self.new(a.conv(*args, **kargs) for a in self.al)

    def conv1d(self, *args, **kargs):
        return self.new(a.conv1d(*args, **kargs) for a in self.al)

    def conv2d(self, *args, **kargs):
        return self.new(a.conv2d(*args, **kargs) for a in self.al)

    def conv3d(self, *args, **kargs):
        return self.new(a.conv3d(*args, **kargs) for a in self.al)

    def max_pool2d(self, *args, **kargs):
        return self.new(a.max_pool2d(*args, **kargs) for a in self.al)

    def avg_pool2d(self, *args, **kargs):
        return self.new(a.avg_pool2d(*args, **kargs) for a in self.al)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return self.new(a.adaptive_avg_pool2d(*args, **kargs) for a in self.al)

    def unsqueeze(self, *args, **kargs):
        return self.new(a.unsqueeze(*args, **kargs) for a in self.al)

    def squeeze(self, *args, **kargs):
        return self.new(a.squeeze(*args, **kargs) for a in self.al)

    def view(self, *args, **kargs):
        return self.new(a.view(*args, **kargs) for a in self.al)

    def gather(self, *args, **kargs):
        return self.new(a.gather(*args, **kargs) for a in self.al)

    def sum(self, *args, **kargs):
        return self.new(a.sum(*args,**kargs) for a in self.al)

    def double(self):
        return self.new(a.double() for a in self.al)

    def float(self):
        return self.new(a.float() for a in self.al)

    def to_dtype(self):
        return self.new(a.to_dtype() for a in self.al)

    def vanillaTensorPart(self):
        # Representative concrete tensor; taken from the first member.
        return self.al[0].vanillaTensorPart()

    def center(self):
        return self.new(a.center() for a in self.al)

    def ub(self):
        return self.new(a.ub() for a in self.al)

    def lb(self):
        return self.new(a.lb() for a in self.al)

    def relu(self):
        return self.new(a.relu() for a in self.al)

    def splitRelu(self, *args, **kargs):
        return self.new(a.splitRelu(*args, **kargs) for a in self.al)

    def softplus(self):
        return self.new(a.softplus() for a in self.al)

    def elu(self):
        return self.new(a.elu() for a in self.al)

    def selu(self):
        return self.new(a.selu() for a in self.al)

    def sigm(self):
        return self.new(a.sigm() for a in self.al)

    def cat(self, other, *args, **kargs):
        return self.new(a.cat(o, *args, **kargs) for a,o in zip(self.al, other.al))

    def split(self, *args, **kargs):
        # Bug fix: transpose the per-member chunk lists with zip(*...), so
        # chunk i of every member lands in the i-th resulting ListDomain.
        # The old zip(generator) produced one ListDomain per MEMBER (each
        # holding all of that member's chunks) instead of one per chunk.
        return [self.new(z) for z in zip(*(a.split(*args, **kargs) for a in self.al))]

    def size(self):
        return self.al[0].size()

    def loss(self, *args, **kargs):
        return sum(a.loss(*args, **kargs) for a in self.al)

    def deep_loss(self, *args, **kargs):
        return sum(a.deep_loss(*args, **kargs) for a in self.al)

    def checkSizes(self):
        for a in self.al:
            a.checkSizes()
        return self
|
|
|
|
class TaggedDomain(object):
    """Wrapper that carries an opaque `tag` alongside a domain `a`.

    Every abstract operation is forwarded to `a` and the tag is preserved
    on the result; `loss` delegates to the tag itself.
    """

    def __init__(self, a, tag = None):
        self.tag = tag
        self.a = a

    def isSafe(self,*args,**kargs):
        return self.a.isSafe(*args, **kargs)

    def isPoint(self):
        return self.a.isPoint()

    def labels(self):
        # Bug fix: raising a bare string is a TypeError in Python 3;
        # raise a real exception carrying the same message.
        raise Exception("Domain Not Suitable For Testing")

    def __mul__(self, flt):
        return TaggedDomain(self.a.__mul__(flt), self.tag)

    def __truediv__(self, flt):
        return TaggedDomain(self.a.__truediv__(flt), self.tag)

    def __add__(self, other):
        if isinstance(other, TaggedDomain):
            return TaggedDomain(self.a.__add__(other.a), self.tag)
        else:
            return TaggedDomain(self.a.__add__(other), self.tag)

    def addPar(self, a,b):
        return TaggedDomain(self.a.addPar(a.a, b.a), self.tag)

    def __sub__(self, other):
        if isinstance(other, TaggedDomain):
            return TaggedDomain(self.a.__sub__(other.a), self.tag)
        else:
            return TaggedDomain(self.a.__sub__(other), self.tag)

    def bmm(self, other):
        return TaggedDomain(self.a.bmm(other), self.tag)

    def matmul(self, other):
        return TaggedDomain(self.a.matmul(other), self.tag)

    def conv(self, *args, **kargs):
        return TaggedDomain(self.a.conv(*args, **kargs) , self.tag)

    def conv1d(self, *args, **kargs):
        return TaggedDomain(self.a.conv1d(*args, **kargs), self.tag)

    def conv2d(self, *args, **kargs):
        return TaggedDomain(self.a.conv2d(*args, **kargs), self.tag)

    def conv3d(self, *args, **kargs):
        return TaggedDomain(self.a.conv3d(*args, **kargs), self.tag)

    def max_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.max_pool2d(*args, **kargs), self.tag)

    def avg_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.avg_pool2d(*args, **kargs), self.tag)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.adaptive_avg_pool2d(*args, **kargs), self.tag)

    def unsqueeze(self, *args, **kargs):
        return TaggedDomain(self.a.unsqueeze(*args, **kargs), self.tag)

    def squeeze(self, *args, **kargs):
        return TaggedDomain(self.a.squeeze(*args, **kargs), self.tag)

    def abstractApplyLeaf(self, *args, **kargs):
        return TaggedDomain(self.a.abstractApplyLeaf(*args, **kargs), self.tag)

    def view(self, *args, **kargs):
        return TaggedDomain(self.a.view(*args, **kargs), self.tag)

    def gather(self, *args, **kargs):
        return TaggedDomain(self.a.gather(*args, **kargs), self.tag)

    def sum(self, *args, **kargs):
        return TaggedDomain(self.a.sum(*args,**kargs), self.tag)

    def double(self):
        return TaggedDomain(self.a.double(), self.tag)

    def float(self):
        return TaggedDomain(self.a.float(), self.tag)

    def to_dtype(self):
        return TaggedDomain(self.a.to_dtype(), self.tag)

    def vanillaTensorPart(self):
        return self.a.vanillaTensorPart()

    def center(self):
        return TaggedDomain(self.a.center(), self.tag)

    def ub(self):
        return TaggedDomain(self.a.ub(), self.tag)

    def lb(self):
        return TaggedDomain(self.a.lb(), self.tag)

    def relu(self):
        return TaggedDomain(self.a.relu(), self.tag)

    def splitRelu(self, *args, **kargs):
        return TaggedDomain(self.a.splitRelu(*args, **kargs), self.tag)

    def diameter(self):
        return self.a.diameter()

    def softplus(self):
        return TaggedDomain(self.a.softplus(), self.tag)

    def elu(self):
        return TaggedDomain(self.a.elu(), self.tag)

    def selu(self):
        return TaggedDomain(self.a.selu(), self.tag)

    def sigm(self):
        return TaggedDomain(self.a.sigm(), self.tag)

    def cat(self, other, *args, **kargs):
        return TaggedDomain(self.a.cat(other.a, *args, **kargs), self.tag)

    def split(self, *args, **kargs):
        return [TaggedDomain(z, self.tag) for z in self.a.split(*args, **kargs)]

    def size(self):
        return self.a.size()

    def loss(self, *args, **kargs):
        # The tag (not the domain) defines how loss is computed.
        return self.tag.loss(self.a, *args, **kargs)

    def deep_loss(self, *args, **kargs):
        return self.a.deep_loss(*args, **kargs)

    def checkSizes(self):
        self.a.checkSizes()
        return self

    def merge(self, other, ref = None):
        return TaggedDomain(self.a.merge(other.a, ref = None if ref is None else ref.a), self.tag)
|
|