9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class TabularConstraintSet:
    """Project tabular attack candidates back onto the valid feature domain.

    The constraints are read from ``metadata``: per-feature types
    (``feature_types``), normalized bounds (``feature_min_norm`` and
    ``feature_max_norm``), optional normalized integer step sizes
    (``integer_step_norm``) and one-hot column groups
    (``categorical_groups``). A feature whose type is neither continuous,
    integer nor categorical is treated as immutable and is always restored
    from the clean sample by :meth:`project`.

    The metadata is treated as immutable after construction, because every
    tensor derived from it is cached.
    """

    # Forward reference: the metadata type is only needed by type checkers.
    def __init__(self, metadata: "TabularAdversarialMetadata"):
        """Store the dataset-level metadata and create the empty caches."""
        self.metadata = metadata
        # Masks and bounds depend on device and dtype.
        self._tensor_cache: dict[tuple[torch.device, torch.dtype], dict[str, torch.Tensor]] = {}
        # One-hot group column indices are always int64, so they depend on device only.
        self._group_index_cache: dict[torch.device, list[tuple[torch.Tensor, int]]] = {}

    def tensors(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return the cached masks and bounds matching ``x``'s device and dtype.

        Every tensor is viewed as ``(1, n_features)`` so that it broadcasts
        over the batch dimension. Keys: ``continuous``, ``integer``,
        ``categorical``, ``numeric`` and ``perturbable`` (boolean masks), plus
        ``min``, ``max`` and ``integer_step`` (values in ``x.dtype``).
        """
        # These tensors are reused in every constrained PGD step, so they are
        # built once per (device, dtype) placement.
        key = (x.device, x.dtype)
        cached = self._tensor_cache.get(key)
        if cached is not None:
            return cached

        cached = {
            "continuous": self._feature_type_mask(x, CONTINUOUS),
            "integer": self._feature_type_mask(x, INTEGER),
            "categorical": self._feature_type_mask(x, CATEGORICAL),
            "min": torch.tensor(
                self.metadata.feature_min_norm, dtype=x.dtype, device=x.device
            ).view(1, -1),
            "max": torch.tensor(
                self.metadata.feature_max_norm, dtype=x.dtype, device=x.device
            ).view(1, -1),
        }
        cached["numeric"] = cached["continuous"] | cached["integer"]
        cached["perturbable"] = cached["numeric"] | cached["categorical"]
        cached["integer_step"] = self._integer_steps(cached["min"])
        self._tensor_cache[key] = cached
        return cached

    def perturbable_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Return the boolean mask of features an attack is allowed to change.

        Used by the attack step to avoid moving immutable features in the
        first place.
        """
        return self.tensors(x)["perturbable"]

    def project(self, x_candidate: torch.Tensor, x_clean: torch.Tensor, epsilon: float) -> torch.Tensor:
        """Project ``x_candidate`` onto the feasible set around ``x_clean``.

        Args:
            x_candidate: Candidate adversarial batch to be repaired.
            x_clean: Original batch; it defines the epsilon interval and
                supplies the values of immutable features.
            epsilon: Maximum absolute change allowed per numeric feature.

        Returns:
            A tensor in which numeric features lie inside both the dataset
            bounds and the epsilon interval, integer features sit on their
            grid, every one-hot group has exactly one active column, and
            immutable features equal ``x_clean``.
        """
        tensors = self.tensors(x_clean)
        lower, upper = self._bounds(x_clean, epsilon, tensors)
        # First force every value into its valid interval, then apply the
        # type-specific fixes.
        x_projected = torch.maximum(torch.minimum(x_candidate, upper), lower)
        x_projected = self._project_integer_features(x_projected, x_clean, lower, upper, tensors)
        x_projected = self.project_categorical_groups(x_projected)
        # Final guardrail: immutable features are copied back from the clean sample.
        return torch.where(tensors["perturbable"], x_projected, x_clean)

    def categorical_gradient_step(self, x_candidate: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Take a discrete attack step on the one-hot groups.

        One-hot columns are discrete, so instead of adding a fractional
        gradient this activates, per sample and per group, the category whose
        gradient entry is largest. ``x_candidate`` itself is returned when the
        metadata defines no categorical groups; otherwise a modified copy is
        returned.
        """
        return self._activate_argmax(x_candidate, grad)

    def project_categorical_groups(self, x_candidate: torch.Tensor) -> torch.Tensor:
        """Leave each one-hot group with exactly one active column.

        The column holding the largest current value in a group wins.
        ``x_candidate`` itself is returned when the metadata defines no
        categorical groups; otherwise a modified copy is returned.
        """
        return self._activate_argmax(x_candidate, x_candidate)

    def _group_indices(self, device: torch.device) -> list[tuple[torch.Tensor, int]]:
        """Return cached ``(column_index_tensor, group_size)`` pairs for ``device``."""
        cached = self._group_index_cache.get(device)
        if cached is None:
            # Built once per device instead of once per group on every call.
            cached = [
                (torch.tensor(group, dtype=torch.long, device=device), len(group))
                for group in self.metadata.categorical_groups
            ]
            self._group_index_cache[device] = cached
        return cached

    def _activate_argmax(self, x_candidate: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Set every one-hot group of ``x_candidate`` to the argmax of ``scores``.

        ``scores`` is only read, so groups are resolved independently even
        when ``scores`` is ``x_candidate`` itself.
        """
        if not self.metadata.categorical_groups:
            return x_candidate
        result = x_candidate.clone()
        for columns, size in self._group_indices(x_candidate.device):
            winner = scores.index_select(1, columns).argmax(dim=1)
            result[:, columns] = F.one_hot(winner, num_classes=size).to(dtype=x_candidate.dtype)
        return result

    def _feature_type_mask(self, x: torch.Tensor, feature_type: str) -> torch.Tensor:
        """Return a ``(1, n_features)`` boolean mask of features of ``feature_type``."""
        return torch.tensor(
            [value == feature_type for value in self.metadata.feature_types],
            dtype=torch.bool,
            device=x.device,
        ).view(1, -1)

    def _bounds(
        self,
        x_clean: torch.Tensor,
        epsilon: float,
        tensors: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the per-feature ``(lower, upper)`` limits for a projection."""
        # Numeric features are restricted both by the dataset bounds and by
        # the epsilon interval around x_clean.
        radius = float(epsilon)
        numeric_lower = torch.maximum(tensors["min"], x_clean - radius)
        numeric_upper = torch.minimum(tensors["max"], x_clean + radius)
        # Categorical features are handled by one-hot projection, not by an
        # epsilon interval, so they only get the dataset bounds.
        lower = torch.where(tensors["categorical"], tensors["min"], numeric_lower)
        upper = torch.where(tensors["categorical"], tensors["max"], numeric_upper)
        return lower, upper

    def _integer_steps(self, minimum: torch.Tensor) -> torch.Tensor:
        """Return the per-feature integer grid step, shaped like ``minimum``."""
        # The default step of 1 is harmless for non-integer columns because
        # the integer mask gates its use later.
        integer_steps = torch.ones_like(minimum)
        for idx, step in (self.metadata.integer_step_norm or {}).items():
            integer_steps[0, int(idx)] = float(step)
        return integer_steps

    def _project_integer_features(
        self,
        x_projected: torch.Tensor,
        x_clean: torch.Tensor,
        lower: torch.Tensor,
        upper: torch.Tensor,
        tensors: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Snap integer features to the nearest grid value inside the bounds.

        Non-integer features are returned unchanged. An integer feature with
        no grid value inside ``[lower, upper]`` falls back to ``x_clean``.
        """
        integer_mask = tensors["integer"]
        if not integer_mask.any():
            return x_projected
        # Integer features may be normalized, so the valid values form a
        # shifted grid: min, min + step, min + 2*step, ...
        origin = tensors["min"]
        # Guard against a zero step before dividing.
        step = torch.clamp(tensors["integer_step"], min=torch.finfo(x_projected.dtype).eps)
        grid_lower = torch.ceil((lower - origin) / step) * step + origin
        grid_upper = torch.floor((upper - origin) / step) * step + origin
        rounded = torch.round((x_projected - origin) / step) * step + origin
        rounded = torch.maximum(torch.minimum(rounded, grid_upper), grid_lower)
        # If epsilon is smaller than the normalized integer step, there may be
        # no valid integer move; keep the clean value in that case.
        has_valid_grid = grid_lower <= grid_upper
        rounded = torch.where(has_valid_grid, rounded, x_clean)
        return torch.where(integer_mask, rounded, x_projected)
|