Skip to content

Documentation for Base Module

AdversarialExampleGenerator

Bases: ABC

Base interface for domain-specific adversarial example generators.

Source code in nebula/addons/defenses/adversarial_training/base.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class AdversarialExampleGenerator(ABC):
    """Base interface for domain-specific adversarial example generators."""

    last_epsilon: float | None = None

    @abstractmethod
    def generate(self, model, x, y, criterion):
        # Concrete generators must return an adversarial version of the input batch.
        raise NotImplementedError

    def _sample_epsilon(self, device: torch.device) -> float:
        # Sample the effective epsilon on the same device as the batch.
        epsilon_max = float(self.config.epsilon)
        if epsilon_max <= 0.0:
            self.last_epsilon = 0.0
            return 0.0

        # Use a different attack strength per batch, capped by the user epsilon.
        epsilon_min = epsilon_max / 4.0
        epsilon_step = epsilon_max / 8.0
        num_values = max(round((epsilon_max - epsilon_min) / epsilon_step) + 1, 1)
        index = int(torch.randint(num_values, (), device=device).item())
        epsilon = min(epsilon_min + index * epsilon_step, epsilon_max)
        self.last_epsilon = epsilon
        return epsilon