Skip to content

Documentation for Adversarial_training Module

AdversarialExampleGenerator

Bases: ABC

Base interface for domain-specific adversarial example generators.

Source code in nebula/addons/defenses/adversarial_training/base.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class AdversarialExampleGenerator(ABC):
    """Base interface for domain-specific adversarial example generators."""

    last_epsilon: float | None = None

    @abstractmethod
    def generate(self, model, x, y, criterion):
        # Concrete generators must return an adversarial version of the input batch.
        raise NotImplementedError

    def _sample_epsilon(self, device: torch.device) -> float:
        # Sample the effective epsilon on the same device as the batch.
        epsilon_max = float(self.config.epsilon)
        if epsilon_max <= 0.0:
            self.last_epsilon = 0.0
            return 0.0

        # Use a different attack strength per batch, capped by the user epsilon.
        epsilon_min = epsilon_max / 4.0
        epsilon_step = epsilon_max / 8.0
        num_values = max(round((epsilon_max - epsilon_min) / epsilon_step) + 1, 1)
        index = int(torch.randint(num_values, (), device=device).item())
        epsilon = min(epsilon_min + index * epsilon_step, epsilon_max)
        self.last_epsilon = epsilon
        return epsilon

AdversarialTrainingDefense

Batch-level adversarial training defense for Nebula models.

Source code in nebula/addons/defenses/adversarial_training/defense.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
class AdversarialTrainingDefense:
    """Batch-level adversarial training defense for Nebula models.

    Wraps one adversarial example generator and one sample logger, and
    computes the training loss for a batch: clean, adversarial-only, or a
    weighted mix, depending on the configuration.
    """

    LOGGED_SAMPLES_PER_ROUND = AdversarialTrainingSampleLogger.LOGGED_SAMPLES_PER_ROUND

    def __init__(self, config: AdversarialTrainingConfig, generator: AdversarialExampleGenerator):
        """Bind the configuration, the generator and a logger built from both."""
        sample_logger = AdversarialTrainingSampleLogger(config, generator)

        self.config = config
        self.generator = generator
        self.sample_logger = sample_logger
        # Same object as the logger's own per-round bookkeeping, exposed under a defense-level name.
        self._logged_adversarial_samples_by_round = sample_logger._logged_samples_by_round

    @classmethod
    def from_participant_config(
        cls,
        participant_config: dict[str, Any],
        partition=None,
    ) -> "AdversarialTrainingDefense | None":
        """Build a defense from a participant configuration.

        Args:
            participant_config: Raw participant configuration.
            partition: Optional data partition; its ``train_set`` supplies
                the tabular metadata when the domain is ``"tabular"``.

        Returns:
            A configured defense, or ``None`` when no adversarial training
            config is present, the image dataset has no normalization, or
            the domain is neither ``"tabular"`` nor ``"image"``.

        Raises:
            ValueError: If tabular metadata is missing or the image attack
                name is unsupported.
        """
        config = config_from_participant(participant_config)
        if config is None:
            return None
        validate_config(config)

        domain = config.domain

        if domain == "tabular":
            generator = TabularConstrainedPGDGenerator(config, cls._get_tabular_metadata(partition))
            return cls(config=config, generator=generator)

        if domain != "image":
            logging.warning(
                "[AdversarialTrainingDefense] Skipping adversarial training: domain '%s' is not implemented yet",
                domain,
            )
            return None

        # Image generators receive the dataset's mean/std, so a dataset without them is skipped.
        normalization = get_image_normalization(config.dataset_name)
        if normalization is not None:
            return cls(config=config, generator=cls._build_image_generator(config, normalization))

        logging.warning(
            "[AdversarialTrainingDefense] Skipping adversarial training: dataset '%s' has no image bounds",
            config.dataset_name,
        )
        return None

    @staticmethod
    def _build_image_generator(config, normalization):
        """Instantiate the image generator named by ``config.attack``.

        Raises:
            ValueError: If the attack is neither ``"fgsm"`` nor ``"pgd"``.
        """
        mean, std = normalization
        known_attacks = (
            ("fgsm", ImageFGSMGenerator),
            ("pgd", ImagePGDGenerator),
        )
        for attack_name, generator_cls in known_attacks:
            if config.attack == attack_name:
                return generator_cls(config, mean, std)
        raise ValueError(ERR_UNSUPPORTED_ATTACK.format(attack=config.attack))

    @staticmethod
    def _get_tabular_metadata(partition) -> TabularAdversarialMetadata:
        """Read the tabular constraints attached to the partition's training set.

        Raises:
            ValueError: If the partition, its ``train_set`` or the
                ``tabular_metadata`` attribute is missing.
        """
        train_set = None if partition is None else getattr(partition, "train_set", None)
        raw_metadata = getattr(train_set, "tabular_metadata", None)
        if raw_metadata is None:
            raise ValueError(ERR_TABULAR_METADATA)

        # Accept either a ready metadata object or its serialized form.
        tabular_metadata = (
            raw_metadata
            if isinstance(raw_metadata, TabularAdversarialMetadata)
            else TabularAdversarialMetadata.from_dict(raw_metadata)
        )

        _log_tabular_metadata(tabular_metadata)
        return tabular_metadata

    def should_apply(self, x: torch.Tensor) -> bool:
        """Decide whether this batch is trained adversarially.

        Probabilities at or above 1 always apply and at or below 0 never
        apply; otherwise one uniform sample is drawn on ``x``'s device.
        """
        probability = self.config.apply_probability
        if probability <= 0.0:
            return False
        if probability >= 1.0:
            return True
        draw = torch.rand((), device=x.device).item()
        return bool(draw < probability)

    def compute_training_step(self, model, x, y, criterion):
        """Compute loss, logits and extra metrics for one batch.

        Returns:
            ``(loss, logits, metrics)``. For a skipped batch this is the
            clean loss/logits with an empty dict. In ``"adversarial"`` mode
            it is the adversarial loss/logits. In any other mode the loss is
            ``clean_weight * clean_loss + adversarial_weight * adv_loss`` and
            the logits are the clean ones.
        """
        if not self.should_apply(x):
            plain_logits = model(x)
            return criterion(plain_logits, y), plain_logits, {}

        # One adversarial batch feeds logging, the adversarial loss and the metrics.
        x_adv = self.generator.generate(model, x, y, criterion)
        self._log_adversarial_samples(model, x, x_adv, y)
        adv_logits = model(x_adv)
        adv_loss = criterion(adv_logits, y)
        adversarial_metrics = {
            "Adversarial Loss": adv_loss,
            "Adversarial Accuracy": self._accuracy(adv_logits, y),
        }

        if self.config.mode == "adversarial":
            # The clean batch is not evaluated at all in this mode.
            return adv_loss, adv_logits, self._extra_metrics(adversarial_metrics)

        clean_logits = model(x)
        clean_loss = criterion(clean_logits, y)
        # Mixed objective: weights come from the config rather than a fixed split.
        loss = self.config.clean_weight * clean_loss + self.config.adversarial_weight * adv_loss

        return loss, clean_logits, self._extra_metrics({"Clean Loss": clean_loss, **adversarial_metrics})

    def _log_adversarial_samples(self, model, x_clean: torch.Tensor, x_adv: torch.Tensor, y: torch.Tensor) -> None:
        """Forward the clean/adversarial pair to the sample logger."""
        self.sample_logger.log(model, x_clean, x_adv, y)

    def _accuracy(self, logits, y):
        """Return the fraction of rows whose arg-max over dim 1 equals ``y``."""
        hits = logits.argmax(dim=1) == y
        return hits.float().mean()

    def _extra_metrics(self, metrics):
        """Return ``metrics``, or an empty dict when metric logging is disabled."""
        return metrics if self.config.log_adversarial_metrics else {}

TabularAdversarialExampleGenerator

Bases: AdversarialExampleGenerator

Base generator for constrained tabular adversarial examples.

Source code in nebula/addons/defenses/adversarial_training/tabular.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class TabularAdversarialExampleGenerator(AdversarialExampleGenerator):
    """Base generator for constrained tabular adversarial examples.

    Holds the configuration, the dataset metadata and the constraint set
    built from that metadata, plus scoring helpers for subclasses.
    """

    def __init__(self, config: AdversarialTrainingConfig, metadata: TabularAdversarialMetadata):
        """Store the config and metadata and build the constraint layer."""
        self.config = config
        self.metadata = metadata
        # The constraint layer is built from the metadata alone.
        self.constraints = TabularConstraintSet(metadata)

    def _alpha(self, epsilon: float) -> float:
        """Return the per-step size for an attack with budget ``epsilon``.

        An explicit ``config.alpha`` wins; otherwise the budget is divided
        by the number of steps (at least one).
        """
        configured = self.config.alpha
        if configured is None:
            step_count = max(int(self.config.steps), 1)
            return float(epsilon) / step_count
        return float(configured)

    def _margin(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return, per row, the best wrong-class logit minus the true-class logit.

        A positive value means some wrong class already outscores the label.
        """
        label_column = y.view(-1, 1)
        true_logit = logits.gather(1, label_column).squeeze(1)

        # Hide the true class so the row-wise maximum only sees the rivals.
        is_true_class = F.one_hot(y, num_classes=logits.size(1)).bool()
        strongest_rival = logits.masked_fill(is_true_class, float("-inf")).max(dim=1).values

        return strongest_rival - true_logit

    def _per_sample_loss(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the unreduced cross-entropy, one value per row."""
        return F.cross_entropy(logits, y, reduction="none")

TabularConstrainedPGDGenerator

Bases: TabularAdversarialExampleGenerator

Constrained PGD generator for tabular adversarial examples.

Source code in nebula/addons/defenses/adversarial_training/tabular.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
class TabularConstrainedPGDGenerator(TabularAdversarialExampleGenerator):
    """Constrained PGD generator for tabular adversarial examples.

    Runs up to ``config.steps`` signed-gradient steps, projects every
    candidate through ``self.constraints`` and keeps, per sample, the best
    candidate seen. "Best" depends on ``config.candidate_selection``:

    * ``"loss_window"``: largest positive loss increase over the clean input,
      optionally capped by ``config.max_loss_increase``.
    * ``"margin_window"``: margin closest to the window bounded by
      ``config.target_margin`` and ``config.max_margin``; ties go to the
      larger margin.
    * any other value: largest margin.
    """

    def generate(self, model, x, y, criterion):
        """Return the best constrained adversarial candidate for each row of ``x``.

        Args:
            model: Callable mapping a batch to logits.
            x: Input batch; treated as 2-D (batch, features) by the per-row
                selection below.
            y: Labels passed to ``criterion`` and to the scoring helpers.
            criterion: Loss whose gradient w.r.t. the input drives each step.

        Returns:
            A detached tensor shaped like ``x``. Rows for which no candidate
            was accepted keep their clean values. When the sampled epsilon is
            not positive, the detached input is returned unchanged.
        """
        # Sample one attack strength for the whole batch (also stored in self.last_epsilon).
        epsilon = self._sample_epsilon(x.device)
        x_clean = x.detach()
        if epsilon <= 0.0:
            return x_clean

        steps = max(int(self.config.steps), 1)
        step_size = self._alpha(epsilon)
        # Cast the boolean mask to the data dtype so it can gate the step by multiplication.
        perturbable_mask = self.constraints.perturbable_mask(x_clean).to(dtype=x_clean.dtype)

        x_adv = x_clean.clone()
        # Per-sample bookkeeping: best candidate so far, its score, and (margin window only)
        # its distance to the target window. The initial values lose every comparison.
        best_adv = x_adv.clone()
        best_score = torch.full((x_clean.size(0),), float("-inf"), dtype=x_clean.dtype, device=x_clean.device)
        best_distance = torch.full((x_clean.size(0),), float("inf"), dtype=x_clean.dtype, device=x_clean.device)
        use_loss_window = self._use_loss_window()
        use_margin_window = self._use_margin_window()
        # The clean baseline is only needed when scores are loss increases.
        clean_loss = self._clean_loss(model, x_clean, y) if use_loss_window else None

        for _ in range(steps):
            # PGD step: move in the sign of the loss gradient, but only on perturbable features.
            x_grad = x_adv.detach().requires_grad_(True)
            logits = model(x_grad)
            loss = criterion(logits, y)
            # Gradient with respect to the input only.
            grad = torch.autograd.grad(loss, x_grad, only_inputs=True)[0]

            candidate = x_adv.detach() + float(step_size) * grad.sign() * perturbable_mask
            # Categorical groups are re-selected from the gradient rather than nudged fractionally.
            candidate = self.constraints.categorical_gradient_step(candidate, grad)
            # This is the key tabular rule: never score or return an invalid candidate.
            candidate = self.constraints.project(candidate, x_clean, epsilon)

            with torch.no_grad():
                # Keep the best candidate per sample, not just the last step.
                candidate_logits = model(candidate)
                if use_loss_window:
                    candidate_score = self._loss_increase(candidate_logits, y, clean_loss)
                    better = self._loss_window_better(candidate_score, best_score)
                elif use_margin_window:
                    candidate_score = self._margin(candidate_logits, y)
                    candidate_distance = self._margin_window_distance(candidate_score)
                    better = self._margin_window_better(candidate_score, candidate_distance, best_score, best_distance)
                    best_distance = torch.where(better, candidate_distance, best_distance)
                else:
                    candidate_score = self._margin(candidate_logits, y)
                    better = candidate_score > best_score
                # `better` has one entry per row; view(-1, 1) broadcasts it across the feature columns.
                best_adv = torch.where(better.view(-1, 1), candidate, best_adv)
                best_score = torch.where(better, candidate_score, best_score)

                # Stop early once every sample meets the configured target.
                if self._target_reached(best_score, best_distance):
                    break

            # The next step continues from the projected candidate, whether or not it was accepted.
            x_adv = candidate

        return best_adv.detach()

    def _use_loss_window(self) -> bool:
        """Return True when candidates are ranked by loss increase."""
        return self.config.candidate_selection == "loss_window"

    def _use_margin_window(self) -> bool:
        """Return True when candidates are ranked by distance to a margin window."""
        return self.config.candidate_selection == "margin_window"

    def _clean_loss(self, model, x_clean: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the per-sample loss of the clean batch, computed without gradients."""
        # Baseline difficulty. Candidate scores become loss(candidate) - loss(clean).
        with torch.no_grad():
            return self._per_sample_loss(model(x_clean), y)

    def _loss_increase(
        self,
        candidate_logits: torch.Tensor,
        y: torch.Tensor,
        clean_loss: torch.Tensor,
    ) -> torch.Tensor:
        """Return the per-sample loss of the candidate minus the clean baseline."""
        return self._per_sample_loss(candidate_logits, y) - clean_loss

    def _loss_window_better(self, candidate_score: torch.Tensor, best_score: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask of samples whose candidate should replace the current best.

        A candidate qualifies when its loss increase is strictly positive, does
        not exceed ``config.max_loss_increase`` (when set) and beats the best
        score so far.
        """
        # A candidate must make the sample harder. If max_loss_increase is set, reject overshoots.
        valid = candidate_score > 0.0
        if self.config.max_loss_increase is not None:
            valid = valid & (candidate_score <= float(self.config.max_loss_increase))
        return valid & (candidate_score > best_score)

    def _margin_window_distance(self, margin: torch.Tensor) -> torch.Tensor:
        """Return how far each margin lies outside the configured window.

        The lower bound is ``config.target_margin`` and the upper bound is
        ``config.max_margin``; either may be ``None`` to leave that side open.
        """
        # Distance is zero inside the window and positive outside. This gives a
        # soft fallback when discrete tabular steps jump over the desired range.
        distance = torch.zeros_like(margin)
        if self.config.target_margin is not None:
            # Shortfall below the lower bound.
            target = torch.full_like(margin, float(self.config.target_margin))
            distance = torch.maximum(distance, target - margin)
        if self.config.max_margin is not None:
            # Overshoot above the upper bound.
            maximum = torch.full_like(margin, float(self.config.max_margin))
            distance = torch.maximum(distance, margin - maximum)
        return distance

    def _margin_window_better(
        self,
        candidate_score: torch.Tensor,
        candidate_distance: torch.Tensor,
        best_score: torch.Tensor,
        best_distance: torch.Tensor,
    ) -> torch.Tensor:
        """Return a boolean mask of samples whose candidate beats the current best.

        Closer to the window wins; at exactly equal distance the larger margin wins.
        """
        closer = candidate_distance < best_distance
        same_distance = candidate_distance == best_distance
        stronger = candidate_score > best_score
        return closer | (same_distance & stronger)

    def _target_reached(self, best_score: torch.Tensor, best_distance: torch.Tensor) -> bool:
        """Return True when every sample in the batch satisfies the early-stop target.

        Loss window: all best scores reach ``config.target_loss_increase``
        (never, when that is ``None``). Margin window: all best distances are
        at most the dtype's machine epsilon. Any other selection mode never
        stops early.
        """
        if self._use_loss_window():
            if self.config.target_loss_increase is None:
                return False
            return bool((best_score >= float(self.config.target_loss_increase)).all().item())
        if self._use_margin_window():
            # Machine epsilon serves as the tolerance for "inside the window".
            return bool((best_distance <= torch.finfo(best_distance.dtype).eps).all().item())
        return False

TabularConstraintSet

Projects tabular attack candidates back to the valid feature domain.

Source code in nebula/addons/defenses/adversarial_training/tabular.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class TabularConstraintSet:
    """Projects tabular attack candidates back to the valid feature domain.

    Feature masks, bounds and the index tensors of the one-hot groups are all
    derived from the dataset-level metadata. They are reused on every
    constrained PGD step, so each is built once and cached.
    """

    def __init__(self, metadata: TabularAdversarialMetadata):
        """Store the metadata and create the (initially empty) caches.

        Args:
            metadata: Object exposing ``feature_types``, ``feature_min_norm``,
                ``feature_max_norm``, ``integer_step_norm`` and
                ``categorical_groups``.
        """
        # The metadata is dataset-level and immutable; derived tensors are cached per device/dtype.
        self.metadata = metadata
        self._tensor_cache: dict[tuple[torch.device, torch.dtype], dict[str, torch.Tensor]] = {}
        # Per-device list of (column-index tensor, group width) for the categorical groups.
        self._group_index_cache: dict[torch.device, list[tuple[torch.Tensor, int]]] = {}

    def tensors(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return the cached masks and bounds matching ``x``'s device and dtype.

        Keys: ``continuous``, ``integer``, ``categorical``, ``numeric``,
        ``perturbable`` (boolean masks), ``min``, ``max`` and
        ``integer_step``. Every tensor has shape ``(1, n_features)`` so it
        broadcasts over the batch dimension.
        """
        key = (x.device, x.dtype)
        cached = self._tensor_cache.get(key)
        if cached is not None:
            return cached

        cached = {
            "continuous": self._feature_type_mask(x, CONTINUOUS),
            "integer": self._feature_type_mask(x, INTEGER),
            "categorical": self._feature_type_mask(x, CATEGORICAL),
            "min": torch.tensor(self.metadata.feature_min_norm, dtype=x.dtype, device=x.device).view(1, -1),
            "max": torch.tensor(self.metadata.feature_max_norm, dtype=x.dtype, device=x.device).view(1, -1),
        }
        cached["numeric"] = cached["continuous"] | cached["integer"]
        # Features of any other type are treated as immutable.
        cached["perturbable"] = cached["numeric"] | cached["categorical"]
        cached["integer_step"] = self._integer_steps(cached["min"])
        self._tensor_cache[key] = cached
        return cached

    def perturbable_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Return the boolean mask of features the attack is allowed to move."""
        return self.tensors(x)["perturbable"]

    def project(self, x_candidate: torch.Tensor, x_clean: torch.Tensor, epsilon: float) -> torch.Tensor:
        """Map ``x_candidate`` to a valid point near ``x_clean``.

        Values are clamped into their interval, integer features are snapped
        to their grid, one-hot groups are repaired, and finally every
        non-perturbable feature is copied back from ``x_clean``.
        """
        tensors = self.tensors(x_clean)
        lower, upper = self._bounds(x_clean, epsilon, tensors)

        # First force every value into its valid interval, then apply type-specific fixes.
        x_projected = torch.max(torch.min(x_candidate, upper), lower)
        x_projected = self._project_integer_features(x_projected, x_clean, lower, upper, tensors)
        x_projected = self.project_categorical_groups(x_projected)
        # Immutable features are copied back from the original clean sample as the final guardrail.
        return torch.where(tensors["perturbable"], x_projected, x_clean)

    def categorical_gradient_step(self, x_candidate: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Set each one-hot group to the category with the largest gradient.

        One-hot columns are discrete, so instead of adding a fractional
        gradient the whole group is re-selected. Returns ``x_candidate``
        itself when the metadata declares no categorical groups, otherwise a
        modified copy.
        """
        if not self.metadata.categorical_groups:
            return x_candidate
        return self._activate_group_argmax(grad, x_candidate)

    def project_categorical_groups(self, x_candidate: torch.Tensor) -> torch.Tensor:
        """Leave exactly one active column (the current maximum) per one-hot group.

        Returns ``x_candidate`` itself when the metadata declares no
        categorical groups, otherwise a modified copy.
        """
        if not self.metadata.categorical_groups:
            return x_candidate
        return self._activate_group_argmax(x_candidate, x_candidate)

    def _group_indices(self, device: torch.device) -> list[tuple[torch.Tensor, int]]:
        """Return cached ``(index tensor, width)`` pairs for the categorical groups.

        Building the index tensors once per device avoids a fresh tensor
        construction for every group on every attack step.
        """
        cached = self._group_index_cache.get(device)
        if cached is None:
            cached = [
                (torch.tensor(group, dtype=torch.long, device=device), len(group))
                for group in self.metadata.categorical_groups
            ]
            self._group_index_cache[device] = cached
        return cached

    def _activate_group_argmax(self, scores: torch.Tensor, x_candidate: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``x_candidate`` with each group one-hot at the row-wise argmax of ``scores``."""
        snapped = x_candidate.clone()
        for group_index, width in self._group_indices(x_candidate.device):
            winner = scores.index_select(1, group_index).argmax(dim=1)
            snapped[:, group_index] = F.one_hot(winner, num_classes=width).to(dtype=x_candidate.dtype)
        return snapped

    def _feature_type_mask(self, x: torch.Tensor, feature_type: str) -> torch.Tensor:
        """Return a ``(1, n_features)`` boolean mask of features with the given type."""
        return torch.tensor(
            [value == feature_type for value in self.metadata.feature_types],
            dtype=torch.bool,
            device=x.device,
        ).view(1, -1)

    def _bounds(
        self,
        x_clean: torch.Tensor,
        epsilon: float,
        tensors: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the element-wise ``(lower, upper)`` limits for a candidate."""
        # Numeric features are restricted both by dataset bounds and by the epsilon ball around x_clean.
        numeric_lower = torch.maximum(tensors["min"], x_clean - float(epsilon))
        numeric_upper = torch.minimum(tensors["max"], x_clean + float(epsilon))
        # Categorical features are handled by one-hot projection, not by an epsilon ball.
        lower = torch.where(tensors["categorical"], tensors["min"], numeric_lower)
        upper = torch.where(tensors["categorical"], tensors["max"], numeric_upper)
        return lower, upper

    def _integer_steps(self, minimum: torch.Tensor) -> torch.Tensor:
        """Return the per-feature grid spacing, shaped like ``minimum``.

        Spacings come from ``metadata.integer_step_norm`` (feature index to
        step); every other feature defaults to 1.
        """
        # Default step=1 is harmless for non-integer columns because the integer mask gates usage later.
        integer_steps = torch.ones_like(minimum)
        for idx, step in (self.metadata.integer_step_norm or {}).items():
            integer_steps[0, int(idx)] = float(step)
        return integer_steps

    def _project_integer_features(
        self,
        x_projected: torch.Tensor,
        x_clean: torch.Tensor,
        lower: torch.Tensor,
        upper: torch.Tensor,
        tensors: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Snap integer features to their grid inside ``[lower, upper]``.

        Non-integer features pass through unchanged. An integer feature with
        no grid point inside its interval falls back to its clean value.
        """
        integer_mask = tensors["integer"]
        if not integer_mask.any():
            return x_projected

        # Integer features may be normalized, so the valid values form a shifted grid:
        # min, min + step, min + 2*step, ...
        # The clamp keeps the divisions below away from a zero step.
        step = torch.clamp(tensors["integer_step"], min=torch.finfo(x_projected.dtype).eps)
        grid_lower = torch.ceil((lower - tensors["min"]) / step) * step + tensors["min"]
        grid_upper = torch.floor((upper - tensors["min"]) / step) * step + tensors["min"]
        rounded = torch.round((x_projected - tensors["min"]) / step) * step + tensors["min"]
        rounded = torch.max(torch.min(rounded, grid_upper), grid_lower)

        # If epsilon is smaller than the normalized integer step, no valid integer move exists.
        has_valid_grid = grid_lower <= grid_upper
        rounded = torch.where(has_valid_grid, rounded, x_clean)
        return torch.where(integer_mask, rounded, x_projected)