Skip to content

Documentation for Aggregator Module

Aggregator

Bases: ABC

Source code in nebula/core/aggregation/aggregator.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
class Aggregator(ABC):
    """Coordinates collection and aggregation of model updates for a federation round.

    Subclasses implement `run_aggregation` with the concrete aggregation
    algorithm. This base class tracks the expected federation nodes, guards the
    round with an async lock, and supports an early exit when every expected
    update arrives before the round timeout.
    """

    def __init__(self, config=None, engine=None):
        self.config = config
        self.engine: Engine = engine
        # This participant's own address; used as the fallback update key.
        self._addr = config.participant["network_args"]["addr"]
        logging.info(f"[{self.__class__.__name__}] Starting Aggregator")
        self._federation_nodes = set()
        self._pending_models_to_aggregate = {}
        self._pending_models_to_aggregate_lock = Locker(name="pending_models_to_aggregate_lock", async_lock=True)
        # Held while a round is pending; released when updates complete or time out.
        self._aggregation_done_lock = Locker(name="aggregation_done_lock", async_lock=True)
        # Signalled by notify_all_updates_received() to skip the remaining wait.
        self._aggregation_waiting_skip = asyncio.Event()

        scenario = self.config.participant["scenario_args"]["federation"]
        self._update_storage = factory_update_handler(scenario, self, self._addr)

    def __str__(self):
        return self.__class__.__name__

    def __repr__(self):
        return self.__str__()

    @property
    def us(self):
        """Federation type UpdateHandler (e.g. DFL-UpdateHandler, CFL-UpdateHandler...)"""
        return self._update_storage

    @abstractmethod
    def run_aggregation(self, models):
        """Aggregate the given models; subclasses supply the algorithm.

        The base implementation only guards against empty input, logging an
        error and returning None when there is nothing to aggregate.
        """
        if len(models) == 0:
            logging.error("Trying to aggregate models when there are no models")
            return None

    async def init(self):
        """Initialize the update handler with this node's role name."""
        await self.us.init(self.engine.rb.get_role_name(True))

    async def update_federation_nodes(self, federation_nodes: set):
        """
        Updates the current set of nodes expected to participate in the upcoming aggregation round.

        This method informs the update handler (`us`) about the new set of federation nodes,
        clears any pending models, and attempts to acquire the aggregation lock to prepare
        for model aggregation. If the aggregation process is already running, it releases the lock
        and tries again to ensure proper cleanup between rounds.

        Args:
            federation_nodes (set): A set of addresses representing the nodes expected to contribute
                                    updates for the next aggregation round.

        Raises:
            Exception: If the aggregation process is already running and the lock cannot be released.
        """
        await self.us.round_expected_updates(federation_nodes=federation_nodes)

        # If the aggregation lock is held, release it to prepare for the new round
        if self._aggregation_done_lock.locked():
            logging.info("🔄  update_federation_nodes | Aggregation lock is held, releasing for new round")
            try:
                await self._aggregation_done_lock.release_async()
            except Exception as e:
                logging.warning(f"🔄  update_federation_nodes | Error releasing aggregation lock: {e}")
                # If we can't release the lock, we might be in the middle of aggregation
                # In this case, we should wait a bit and try again
                await asyncio.sleep(0.1)
                if self._aggregation_done_lock.locked():
                    raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.")

        # Now acquire the lock for the new round
        self._federation_nodes = federation_nodes
        self._pending_models_to_aggregate.clear()
        await self._aggregation_done_lock.acquire_async(
            timeout=self.config.participant["aggregator_args"]["aggregation_timeout"]
        )

    def get_nodes_pending_models_to_aggregate(self):
        """Return the set of nodes expected in the current round."""
        return self._federation_nodes

    async def get_aggregation(self):
        """
        Handles the aggregation process for a training round.

        This method waits for all expected model updates from federation nodes or until a timeout occurs.
        It uses an asynchronous lock to coordinate access and includes an early exit mechanism if all
        updates are received before the timeout. Once the condition is satisfied, it releases the lock,
        collects the updates, identifies any missing nodes, and publishes an `AggregationEvent`.
        Finally, it runs the aggregation algorithm and returns the result.

        Returns:
            Any: The result of the aggregation process, as returned by `run_aggregation`.

        Raises:
            TimeoutError: If the aggregation lock is not acquired within the defined timeout.
            asyncio.CancelledError: If the aggregation lock acquisition is cancelled.
            Exception: For any other unexpected errors during the aggregation process.
        """
        # Initialized before the try so the `finally` below never hits an
        # unbound local if an early await (e.g. notify_if_all_updates_received) raises.
        lock_acquired = False
        try:
            timeout = self.config.participant["aggregator_args"]["aggregation_timeout"]
            logging.info(f"Aggregation timeout: {timeout} starts...")
            await self.us.notify_if_all_updates_received()
            lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout))
            skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait())
            done, _pending = await asyncio.wait(
                [lock_task, skip_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
            lock_acquired = lock_task in done
            if skip_task in done:
                logging.info("Skipping aggregation timeout, updates received before grace time")
                self._aggregation_waiting_skip.clear()
                if not lock_acquired:
                    lock_task.cancel()
                try:
                    await lock_task  # Clean cancel
                except asyncio.CancelledError:
                    pass
            else:
                # The lock resolved first: reap the skip waiter so the task
                # does not linger until the event is set in a later round.
                skip_task.cancel()
                try:
                    await skip_task
                except asyncio.CancelledError:
                    pass

        except TimeoutError:
            logging.exception("🔄  get_aggregation | Timeout reached for aggregation")
        except asyncio.CancelledError:
            logging.exception("🔄  get_aggregation | Lock acquisition was cancelled")
        except Exception as e:
            logging.exception(f"🔄  get_aggregation | Error acquiring lock: {e}")
        finally:
            if lock_acquired or self._aggregation_done_lock.locked():
                await self._aggregation_done_lock.release_async()

        await self.us.stop_notifying_updates()
        updates = await self.us.get_round_updates()
        if not updates:
            logging.info("🔄  get_aggregation | No updates has been received..resolving conflict to continue...")
            updates = {self._addr: await self.engine.resolve_missing_updates()}

        missing_nodes = await self.us.get_round_missing_nodes()
        if missing_nodes:
            logging.info(f"🔄  get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}")
        else:
            logging.info("🔄  get_aggregation | All models accounted for, proceeding with aggregation.")

        agg_event = AggregationEvent(updates, self._federation_nodes, missing_nodes)
        await EventManager.get_instance().publish_node_event(agg_event)
        aggregated_result = self.run_aggregation(updates)
        return aggregated_result

    def print_model_size(self, model):
        """Log the in-memory size (MB) of a state-dict-like model mapping."""
        total_memory = 0

        for _, param in model.items():
            num_params = param.numel()
            memory_usage = param.element_size() * num_params
            total_memory += memory_usage

        total_memory_in_mb = total_memory / (1024**2)
        logging.info(f"print_model_size | Model size: {total_memory_in_mb} MB")

    async def notify_all_updates_received(self):
        """Signal get_aggregation() that all expected updates arrived early."""
        self._aggregation_waiting_skip.set()

us property

Federation type UpdateHandler (e.g. DFL-UpdateHandler, CFL-UpdateHandler...)

get_aggregation() async

Handles the aggregation process for a training round.

This method waits for all expected model updates from federation nodes or until a timeout occurs. It uses an asynchronous lock to coordinate access and includes an early exit mechanism if all updates are received before the timeout. Once the condition is satisfied, it releases the lock, collects the updates, identifies any missing nodes, and publishes an AggregationEvent. Finally, it runs the aggregation algorithm and returns the result.

Returns:

Name Type Description
Any

The result of the aggregation process, as returned by run_aggregation.

Raises:

Type Description
TimeoutError

If the aggregation lock is not acquired within the defined timeout.

CancelledError

If the aggregation lock acquisition is cancelled.

Exception

For any other unexpected errors during the aggregation process.

Source code in nebula/core/aggregation/aggregator.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
async def get_aggregation(self):
    """
    Handles the aggregation process for a training round.

    This method waits for all expected model updates from federation nodes or until a timeout occurs.
    It uses an asynchronous lock to coordinate access and includes an early exit mechanism if all
    updates are received before the timeout. Once the condition is satisfied, it releases the lock,
    collects the updates, identifies any missing nodes, and publishes an `AggregationEvent`.
    Finally, it runs the aggregation algorithm and returns the result.

    Returns:
        Any: The result of the aggregation process, as returned by `run_aggregation`.

    Raises:
        TimeoutError: If the aggregation lock is not acquired within the defined timeout.
        asyncio.CancelledError: If the aggregation lock acquisition is cancelled.
        Exception: For any other unexpected errors during the aggregation process.
    """
    # Initialized before the try so the `finally` below never hits an
    # unbound local if an early await (e.g. notify_if_all_updates_received) raises.
    lock_acquired = False
    try:
        timeout = self.config.participant["aggregator_args"]["aggregation_timeout"]
        logging.info(f"Aggregation timeout: {timeout} starts...")
        await self.us.notify_if_all_updates_received()
        lock_task = asyncio.create_task(self._aggregation_done_lock.acquire_async(timeout=timeout))
        skip_task = asyncio.create_task(self._aggregation_waiting_skip.wait())
        done, _pending = await asyncio.wait(
            [lock_task, skip_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
        lock_acquired = lock_task in done
        if skip_task in done:
            logging.info("Skipping aggregation timeout, updates received before grace time")
            self._aggregation_waiting_skip.clear()
            if not lock_acquired:
                lock_task.cancel()
            try:
                await lock_task  # Clean cancel
            except asyncio.CancelledError:
                pass
        else:
            # The lock resolved first: reap the skip waiter so the task
            # does not linger until the event is set in a later round.
            skip_task.cancel()
            try:
                await skip_task
            except asyncio.CancelledError:
                pass

    except TimeoutError:
        logging.exception("🔄  get_aggregation | Timeout reached for aggregation")
    except asyncio.CancelledError:
        logging.exception("🔄  get_aggregation | Lock acquisition was cancelled")
    except Exception as e:
        logging.exception(f"🔄  get_aggregation | Error acquiring lock: {e}")
    finally:
        if lock_acquired or self._aggregation_done_lock.locked():
            await self._aggregation_done_lock.release_async()

    await self.us.stop_notifying_updates()
    updates = await self.us.get_round_updates()
    if not updates:
        logging.info("🔄  get_aggregation | No updates has been received..resolving conflict to continue...")
        updates = {self._addr: await self.engine.resolve_missing_updates()}

    missing_nodes = await self.us.get_round_missing_nodes()
    if missing_nodes:
        logging.info(f"🔄  get_aggregation | Aggregation incomplete, missing models from: {missing_nodes}")
    else:
        logging.info("🔄  get_aggregation | All models accounted for, proceeding with aggregation.")

    agg_event = AggregationEvent(updates, self._federation_nodes, missing_nodes)
    await EventManager.get_instance().publish_node_event(agg_event)
    aggregated_result = self.run_aggregation(updates)
    return aggregated_result

update_federation_nodes(federation_nodes) async

Updates the current set of nodes expected to participate in the upcoming aggregation round.

This method informs the update handler (us) about the new set of federation nodes, clears any pending models, and attempts to acquire the aggregation lock to prepare for model aggregation. If the aggregation process is already running, it releases the lock and tries again to ensure proper cleanup between rounds.

Parameters:

Name Type Description Default
federation_nodes set

A set of addresses representing the nodes expected to contribute updates for the next aggregation round.

required

Raises:

Type Description
Exception

If the aggregation process is already running and the lock cannot be released.

Source code in nebula/core/aggregation/aggregator.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
async def update_federation_nodes(self, federation_nodes: set):
    """
    Register the set of nodes expected in the next aggregation round.

    Notifies the update handler (`us`) of the new expected nodes, releases a
    stale aggregation lock left over from a previous round (if any), clears
    pending models, and finally re-acquires the aggregation lock so the new
    round can begin.

    Args:
        federation_nodes (set): Addresses of the nodes expected to send
            updates during the upcoming round.

    Raises:
        Exception: If the lock cannot be released because an aggregation is
            currently running.
    """
    await self.us.round_expected_updates(federation_nodes=federation_nodes)

    round_lock = self._aggregation_done_lock
    if round_lock.locked():
        # A previous round left the lock held; free it before starting over.
        logging.info("🔄  update_federation_nodes | Aggregation lock is held, releasing for new round")
        try:
            await round_lock.release_async()
        except Exception as e:
            logging.warning(f"🔄  update_federation_nodes | Error releasing aggregation lock: {e}")
            # Release failed: aggregation may be mid-flight. Give it a moment,
            # then bail out if the lock is still held.
            await asyncio.sleep(0.1)
            if round_lock.locked():
                raise Exception("It is not possible to set nodes to aggregate when the aggregation is running.")

    # Fresh round state: record the expected nodes, drop stale models, and
    # take the lock that the aggregation phase will wait on.
    self._federation_nodes = federation_nodes
    self._pending_models_to_aggregate.clear()
    await round_lock.acquire_async(timeout=self.config.participant["aggregator_args"]["aggregation_timeout"])