-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathconvex.py
More file actions
126 lines (94 loc) · 3.03 KB
/
convex.py
File metadata and controls
126 lines (94 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
from numpy.linalg import norm
from sklearn.isotonic import isotonic_regression
from ya_glm.opt.base import Func
class Constraint(Func):
    """Base class for constraint sets, viewed as indicator functions.

    Subclasses in this module supply ``_prox`` (a projection onto the set)
    and report ``is_proximable``; this base only fixes the value and the
    smoothness flag.
    """

    def _eval(self, x):
        # NOTE(review): a true indicator would be +inf outside the set; here
        # the value is always 0 and feasibility of ``x`` is never checked --
        # presumably iterates are assumed feasible after projection.
        value = 0
        return value

    @property
    def is_smooth(self):
        # Indicator functions are not differentiable.
        smooth = False
        return smooth
class Positive(Constraint):
    """Elementwise sign constraint: entries are kept only where positive."""

    def _prox(self, x, step=1):
        """Project ``x`` onto the non-negative orthant.

        Entries that are not strictly positive (including NaN) become 0.
        The result has the same shape and dtype as ``x``; ``step`` is unused
        because a projection does not depend on the step size.
        """
        projected = np.zeros_like(x)
        # Copy only the strictly positive entries into the zero buffer.
        np.copyto(projected, x, where=(x > 0))
        return projected

    @property
    def is_proximable(self):
        # The projection above is available in closed form.
        proximable = True
        return proximable
class LinearEquality(Constraint):
    """Affine equality constraint ``A @ x == b``.

    Projection formula credited to PyUNLocBoX:
    https://github.com/epfl-lts2/pyunlocbox/
    """

    def __init__(self, A, b):
        self.A = A
        self.b = b
        # Compute the Moore-Penrose pseudo-inverse once; every prox reuses it.
        self.pinvA = np.linalg.pinv(A)

    def _prox(self, x, step=1):
        """Return ``x - pinv(A) @ (A @ x - b)``.

        This is the Euclidean projection onto ``{z : A z = b}`` when that
        system is consistent. ``step`` is unused.
        """
        violation = self.A @ x - self.b
        correction = self.pinvA @ violation
        return x - correction

    @property
    def is_proximable(self):
        proximable = True
        return proximable
class Simplex(Constraint):
    """Scaled-simplex constraint, projected via ``project_simplex``.

    ``mult`` is forwarded to ``project_simplex`` as its ``z`` argument, the
    total the entries are meant to add up to (``mult=1`` is the probability
    simplex). Inputs of any shape are flattened for the projection and the
    result is reshaped back to the input's shape.
    """

    def __init__(self, mult=1):
        self.mult = mult

    def _prox(self, x, step=1):
        # ``step`` is unused: a projection does not depend on the step size.
        flat = x.reshape(-1)
        projected = project_simplex(flat, z=self.mult)
        return projected.reshape(x.shape)

    @property
    def is_proximable(self):
        proximable = True
        return proximable
class L1Ball(Constraint):
    """L1-ball constraint, projected via ``project_l1_ball``.

    ``mult`` is forwarded as the ball radius ``z``. Inputs of any shape are
    flattened for the projection and reshaped back afterwards.
    """

    def __init__(self, mult=1):
        self.mult = mult

    def _prox(self, x, step=1):
        # ``step`` is unused: a projection does not depend on the step size.
        flat = x.reshape(-1)
        projected = project_l1_ball(flat, z=self.mult)
        return projected.reshape(x.shape)

    @property
    def is_proximable(self):
        proximable = True
        return proximable
# See https://gist.github.com/mblondel/6f3b7aaad90606b98f71
# for more algorithms.
def project_simplex(v, z=1):
    """Euclidean projection of a 1-D array onto the scaled simplex.

    Solves ``min_w ||w - v||_2`` subject to ``w >= 0`` and ``sum(w) == z``
    with the O(n log n) sort-based algorithm (Duchi et al., 2008; see
    https://gist.github.com/mblondel/6f3b7aaad90606b98f71).

    Parameters
    ----------
    v : array-like of shape (n_features,)
        Point to project. Must be one-dimensional.
    z : float, default=1
        What the entries need to add up to, e.g. ``z=1`` for the probability
        simplex. Must be non-negative.

    Returns
    -------
    w : ndarray of shape (n_features,)
        The projection: non-negative entries summing to ``z``.

    Raises
    ------
    ValueError
        If ``v`` is not 1-D, or if ``z < 0`` (the simplex is then empty).
    """
    v = np.asarray(v)
    if v.ndim != 1:
        raise ValueError("v must be 1-D, got shape {}".format(v.shape))
    if z < 0:
        raise ValueError("z must be non-negative, got {}".format(z))
    if z == 0:
        # The simplex degenerates to the single point 0.
        return np.zeros(v.shape)

    # No "sum(v) <= z" shortcut here: it returned infeasible points (negative
    # entries or sum < z). That shortcut is only valid for the L1 ball and
    # lives in project_l1_ball.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u) - z
    ind = np.arange(1, v.shape[0] + 1)
    # For z > 0 the first sorted entry always satisfies cond, so rho exists.
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / rho
    return np.maximum(v - theta, 0)


def project_l1_ball(v, z=1):
    """Euclidean projection of a 1-D array onto the L1 ball of radius ``z``.

    Points already inside the ball are returned unchanged (as a copy).
    Otherwise the magnitudes are projected onto the simplex of size ``z``
    and the original signs are restored.

    Raises
    ------
    ValueError
        Propagated from ``project_simplex`` when ``z < 0`` or when ``v``
        lies outside the ball and is not 1-D.
    """
    v = np.asarray(v)
    magnitudes = np.abs(v)
    if np.sum(magnitudes) <= z:
        # Already feasible; copy so callers never alias their input.
        return v.copy()
    return np.sign(v) * project_simplex(magnitudes, z)
class L2Ball(Constraint):
    """Euclidean-ball constraint ``norm(x) <= mult``.

    ``norm`` is ``numpy.linalg.norm`` with default arguments, i.e. the 2-norm
    of the flattened array for inputs of any shape.

    Raises
    ------
    ValueError
        If ``mult`` is not strictly positive.
    """

    def __init__(self, mult=1):
        # Validate with an exception rather than ``assert``, which is stripped
        # when Python runs with -O. ``not mult > 0`` also rejects NaN.
        if not mult > 0:
            raise ValueError("mult must be positive, got {}".format(mult))
        self.mult = mult

    def _prox(self, x, step=1):
        """Project ``x`` onto the ball by radial rescaling.

        Points inside the ball (scale factor 1) come back unchanged in value.
        Points outside are shrunk onto the sphere of radius ``mult``.
        ``step`` is unused.
        """
        scale = max(norm(x) / self.mult, 1.0)
        return x / scale

    @property
    def is_proximable(self):
        return True
class Isotonic(Constraint):
    """Monotonicity constraint on the entries of ``x``.

    With ``increasing=True`` the set is ``x1 <= ... <= xn``; with
    ``increasing=False`` it is ``x1 >= ... >= xn``.
    """
    # TODO: allow for general isotonic regression
    # where the order relations are a simple directed
    # graph. For an algorithm see Nemeth and Nemeth, "How to project onto an
    # isotone projection cone", JLLA 2010

    def __init__(self, increasing=True):
        self.increasing = increasing

    def _prox(self, x, step=1):
        # Projecting onto the monotone cone is an unweighted isotonic
        # regression, which sklearn solves with PAVA. ``step`` is unused.
        fitted = isotonic_regression(x, increasing=self.increasing)
        return fitted

    @property
    def is_proximable(self):
        proximable = True
        return proximable