
Documentation for Cflupdatehandler Module

CFLUpdateHandler

Bases: UpdateHandler

Handles updates received in a cross-silo/federated learning setup, managing synchronization and aggregation across distributed nodes.

Attributes:

_aggregator (Aggregator): Reference to the aggregator managing the global model.
_addr (str): Local address of the node.
_buffersize (int): Maximum number of updates stored per source node.
_updates_storage (dict): Received updates, stored per source node.
_sources_expected (set): Nodes expected to send updates this round.
_sources_received (set): Nodes that have sent updates this round.
_missing_ones (set): Nodes whose expected updates are missing.
_role (str): Role of this node (e.g., trainer or server).
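A minimal sketch of the per-source buffer that `_updates_storage` maintains, assuming only what the source listing shows: each source maps to a `collections.deque` bounded by `maxlen=_buffersize`, so the oldest update is dropped once the buffer is full. Round numbers stand in for `Update` objects, `BUFFER_SIZE` is a hypothetical value (the real `MAX_UPDATE_BUFFER_SIZE` is not shown on this page), and deques are created lazily with `setdefault`, whereas the handler creates them in `round_expected_updates` and `_update_source`.

```python
from collections import deque

# Hypothetical buffer size; the real default (MAX_UPDATE_BUFFER_SIZE) is not shown here.
BUFFER_SIZE = 2

# One bounded history per source node, as in _updates_storage.
storage: dict[str, deque] = {}


def store(source: str, round_no: int) -> None:
    # Create the bounded deque on first use, then append; maxlen evicts the oldest entry.
    storage.setdefault(source, deque(maxlen=BUFFER_SIZE)).append(round_no)


for r in range(4):
    store("node-a", r)

print(list(storage["node-a"]))  # → [2, 3]
```

Bounding each history with `maxlen` keeps memory per neighbour constant no matter how many rounds a node stays in the federation.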

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
class CFLUpdateHandler(UpdateHandler):
    """
    Handles updates received in a cross-silo/federated learning setup,
    managing synchronization and aggregation across distributed nodes.

    Attributes:
        _aggregator (Aggregator): Reference to the aggregator managing the global model.
        _addr (str): Local address of the node.
        _buffersize (int): Max number of updates to store per node.
        _updates_storage (dict): Stores received updates per source node.
        _sources_expected (set): Set of nodes expected to send updates this round.
        _sources_received (set): Set of nodes that have sent updates this round.
        _missing_ones (set): Tracks nodes whose updates are missing.
        _role (str): Role of this node (e.g., trainer or server).
    """

    def __init__(self, aggregator, addr, buffersize=MAX_UPDATE_BUFFER_SIZE):
        self._addr = addr
        self._aggregator: Aggregator = aggregator
        self._buffersize = buffersize
        self._updates_storage: dict[str, deque[Update]] = {}
        self._updates_storage_lock = Locker(name="updates_storage_lock", async_lock=True)
        self._sources_expected = set()
        self._sources_received = set()
        self._round_updates_lock = Locker(
            name="round_updates_lock", async_lock=True
        )  # acquired when starting to check whether all updates have arrived
        self._update_federation_lock = Locker(name="update_federation_lock", async_lock=True)
        self._notification_sent_lock = Locker(name="notification_sent_lock", async_lock=True)
        self._notification = False
        self._missing_ones = set()
        self._role = ""

    @property
    def us(self):
        """Returns the internal updates storage dictionary."""
        return self._updates_storage

    @property
    def agg(self):
        """Returns the aggregator instance."""
        return self._aggregator

    async def init(self, config):
        """
        Initializes the handler: stores `config` as the node's role and
        subscribes to UpdateNeighborEvent and UpdateReceivedEvent node events.
        """
        self._role = config
        await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update)
        await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update)

    async def round_expected_updates(self, federation_nodes: set):
        """
        Sets the expected nodes for the current training round and updates storage.

        Args:
            federation_nodes (set): Nodes expected to send updates this round.
        """
        await self._update_federation_lock.acquire_async()
        await self._updates_storage_lock.acquire_async()
        self._sources_expected = federation_nodes.copy()
        self._sources_received.clear()

        # Initialize new nodes
        for fn in federation_nodes:
            if fn not in self.us:
                self.us[fn] = deque(maxlen=self._buffersize)

        # Clear removed nodes
        removed_nodes = [node for node in self._updates_storage.keys() if node not in federation_nodes]
        for rn in removed_nodes:
            del self._updates_storage[rn]

        await self._updates_storage_lock.release_async()
        await self._update_federation_lock.release_async()

        # Release the round-updates lock if a previous wait left it held
        if self._round_updates_lock.locked():
            await self._round_updates_lock.release_async()

        self._notification = False

    async def storage_update(self, updt_received_event: UpdateReceivedEvent):
        """
        Stores a received update if it comes from an expected source.

        Args:
            updt_received_event (UpdateReceivedEvent): The event containing the update.
        """
        time_received = time.time()
        (model, weight, source, round, _) = await updt_received_event.get_event_data()

        if source in self._sources_expected:
            updt = Update(model, weight, source, round, time_received)
            await self._updates_storage_lock.acquire_async()
            if updt in self.us[source]:
                logging.info(f"Discard | Already received update from source: {source} for round: {round}")
            else:
                self.us[source].append(updt)
                logging.info(
                    f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}"
                )

                self._sources_received.add(source)
                updates_left = self._sources_expected.difference(self._sources_received)
                logging.info(
                    f"Updates received ({len(self._sources_received)}/{len(self._sources_expected)}) | Missing nodes: {updates_left}"
                )
                if self._round_updates_lock.locked() and not updates_left:
                    all_rec = await self._all_updates_received()
                    if all_rec:
                        await self._notify()
            await self._updates_storage_lock.release_async()
        else:
            if source not in self._sources_received:
                logging.info(f"Discard update | source: {source} not in expected updates for this Round")

    async def get_round_updates(self) -> dict[str, tuple[object, float]]:
        """
        Retrieves the latest updates received this round.

        Returns:
            dict: Mapping of source to (model, weight) tuples.
        """
        await self._updates_storage_lock.acquire_async()
        updates_missing = self._sources_expected.difference(self._sources_received)
        if updates_missing:
            self._missing_ones = updates_missing
            logging.info(f"Missing updates from sources: {updates_missing}")
        updates = {}
        for sr in self._sources_received:
            # A trainer node ignores its own update once it has also received one from the server
            if self._role == "trainer" and len(self._sources_received) > 1 and sr == self._addr:
                continue
            updt: Update = self.us[sr][-1]  # Last update received from this source
            updates[sr] = (updt.model, updt.weight)
        await self._updates_storage_lock.release_async()
        return updates

    async def notify_federation_update(self, updt_nei_event: UpdateNeighborEvent):
        """
        Reacts to neighbor updates (e.g., join or leave).

        Args:
            updt_nei_event (UpdateNeighborEvent): The neighbor update event.
        """
        source, remove = await updt_nei_event.get_event_data()
        if not remove:
            if self._round_updates_lock.locked():
                logging.info(f"Source: {source} will be counted next round")
            else:
                await self._update_source(source, remove)
        else:
            if source not in self._sources_received:  # Not received update from this source yet
                await self._update_source(source, remove=True)
                await self._all_updates_received()  # Verify if discarding node aggregation could be done
            else:
                logging.info(f"Already received update from: {source}, it will be discarded next round")

    async def _update_source(self, source, remove=False):
        """
        Updates internal tracking for a specific source node.

        Args:
            source (str): Source node ID.
            remove (bool): Whether the source should be removed.
        """
        logging.info(f"🔄 Update | remove: {remove} | source: {source}")
        await self._updates_storage_lock.acquire_async()
        if remove:
            self._sources_expected.discard(source)
        else:
            self.us[source] = deque(maxlen=self._buffersize)
            self._sources_expected.add(source)
        logging.info(f"federation nodes expected this round: {self._sources_expected}")
        await self._updates_storage_lock.release_async()

    async def get_round_missing_nodes(self):
        """
        Returns nodes whose updates were expected but not received.
        """
        return self._missing_ones

    async def notify_if_all_updates_received(self):
        """
        Acquires the round-updates lock to start waiting for updates and notifies
        the aggregator immediately if all expected updates have already been received.
        """
        logging.info("Set notification when all expected updates received")
        await self._round_updates_lock.acquire_async()
        await self._updates_storage_lock.acquire_async()
        all_received = await self._all_updates_received()
        await self._updates_storage_lock.release_async()
        if all_received:
            await self._notify()

    async def stop_notifying_updates(self):
        """
        Stops waiting for updates and releases the round-updates lock if it is held.
        """
        if self._round_updates_lock.locked():
            logging.info("Stop notification updates")
            await self._round_updates_lock.release_async()

    async def _notify(self):
        """
        Notifies the aggregator that all updates have been received.
        """
        await self._notification_sent_lock.acquire_async()
        if self._notification:
            await self._notification_sent_lock.release_async()
            return
        self._notification = True
        await self.stop_notifying_updates()
        await self._notification_sent_lock.release_async()
        logging.info("🔄 Notifying aggregator to release aggregation")
        await self.agg.notify_all_updates_received()

    async def _all_updates_received(self):
        """
        Checks if updates from all expected nodes have been received.

        Returns:
            bool: True if all updates are received, False otherwise.
        """
        updates_left = self._sources_expected.difference(self._sources_received)
        all_received = False
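        if len(updates_left) == 0:
            logging.info("All updates have been received this round")
            # Release the round-updates lock only if it is held (this method is also
            # called from notify_federation_update, where the lock may be free)
            if self._round_updates_lock.locked():
                await self._round_updates_lock.release_async()
            all_received = True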
        return all_received

agg property

Returns the aggregator instance.

us property

Returns the internal updates storage dictionary.

get_round_missing_nodes() async

Returns nodes whose updates were expected but not received.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def get_round_missing_nodes(self):
    """
    Returns nodes whose updates were expected but not received.
    """
    return self._missing_ones
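`get_round_missing_nodes` returns whatever `get_round_updates` last stored in `_missing_ones`, which is the set difference between expected and received sources. In the listing, `_missing_ones` is reassigned only when that difference is non-empty, so after a round with no missing nodes the method may still return the previous round's set. The hypothetical helper below sketches the computation itself as a pure function that always reflects the current round.

```python
def round_missing(expected: set[str], received: set[str]) -> set[str]:
    """Nodes expected this round whose update has not arrived (hypothetical helper)."""
    return expected - received


print(round_missing({"a", "b", "c"}, {"a", "c"}))  # → {'b'}
print(round_missing({"a"}, {"a"}))  # → set()
```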

get_round_updates() async

Retrieves the latest updates received this round.

Returns:

dict[str, tuple[object, float]]: Mapping of source to (model, weight) tuples.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def get_round_updates(self) -> dict[str, tuple[object, float]]:
    """
    Retrieves the latest updates received this round.

    Returns:
        dict: Mapping of source to (model, weight) tuples.
    """
    await self._updates_storage_lock.acquire_async()
    updates_missing = self._sources_expected.difference(self._sources_received)
    if updates_missing:
        self._missing_ones = updates_missing
        logging.info(f"Missing updates from sources: {updates_missing}")
    updates = {}
    for sr in self._sources_received:
        # A trainer node ignores its own update once it has also received one from the server
        if self._role == "trainer" and len(self._sources_received) > 1 and sr == self._addr:
            continue
        updt: Update = self.us[sr][-1]  # Last update received from this source
        updates[sr] = (updt.model, updt.weight)
    await self._updates_storage_lock.release_async()
    return updates
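The selection rule in `get_round_updates` can be sketched with plain data: for each source that sent an update this round, take the newest entry in its deque, except that a node whose role is `"trainer"` skips its own address when at least one other source also reported. `Upd`, `latest_updates`, and the node names are hypothetical simplifications, and locking is omitted.

```python
from collections import deque
from typing import NamedTuple


class Upd(NamedTuple):
    """Hypothetical, trimmed stand-in for Update."""
    model: object
    weight: float
    round: int


def latest_updates(storage, received, role, own_addr):
    updates = {}
    for src in received:
        # A trainer skips its own update when another source also reported.
        if role == "trainer" and len(received) > 1 and src == own_addr:
            continue
        last = storage[src][-1]  # most recent update from this source
        updates[src] = (last.model, last.weight)
    return updates


storage = {
    "self": deque([Upd("local-model", 1.0, 3)], maxlen=4),
    "server": deque([Upd("old-global", 1.0, 2), Upd("global", 1.0, 3)], maxlen=4),
}
print(latest_updates(storage, {"self", "server"}, "trainer", "self"))
# → {'server': ('global', 1.0)}
```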

init(config) async

Initializes the handler: stores config as the node's role (e.g., trainer or server) and subscribes to the UpdateNeighborEvent and UpdateReceivedEvent node events.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def init(self, config):
    """
    Initializes the handler: stores `config` as the node's role and
    subscribes to UpdateNeighborEvent and UpdateReceivedEvent node events.
    """
    self._role = config
    await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update)
    await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update)
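The subscription pattern in `init` can be illustrated with a toy event bus. `MiniEventBus`, `UpdateReceived`, and `Handler` are hypothetical stand-ins written for this sketch; they do not reproduce Nebula's `EventManager.subscribe_node_event` API, only the idea that the handler registers coroutine callbacks that the bus awaits when a matching event is published.

```python
import asyncio
from collections import defaultdict


class MiniEventBus:
    """Hypothetical, minimal stand-in for an event manager."""

    def __init__(self):
        self._subs = defaultdict(list)

    async def subscribe(self, event_type, callback):
        self._subs[event_type].append(callback)

    async def publish(self, event):
        # Await every callback registered for this event's exact type.
        for cb in self._subs[type(event)]:
            await cb(event)


class UpdateReceived:
    """Hypothetical event type carrying only the sender."""

    def __init__(self, source):
        self.source = source


class Handler:
    def __init__(self):
        self.role = ""
        self.seen = []

    async def init(self, role, bus):
        self.role = role  # mirrors self._role = config
        await bus.subscribe(UpdateReceived, self.on_update)

    async def on_update(self, event):
        self.seen.append(event.source)


async def main():
    bus = MiniEventBus()
    h = Handler()
    await h.init("trainer", bus)
    await bus.publish(UpdateReceived("node-b"))
    return h


h = asyncio.run(main())
print(h.role, h.seen)  # → trainer ['node-b']
```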

notify_federation_update(updt_nei_event) async

Reacts to neighbor updates (e.g., join or leave).

Parameters:

updt_nei_event (UpdateNeighborEvent, required): The neighbor update event.
Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def notify_federation_update(self, updt_nei_event: UpdateNeighborEvent):
    """
    Reacts to neighbor updates (e.g., join or leave).

    Args:
        updt_nei_event (UpdateNeighborEvent): The neighbor update event.
    """
    source, remove = await updt_nei_event.get_event_data()
    if not remove:
        if self._round_updates_lock.locked():
            logging.info(f"Source: {source} will be counted next round")
        else:
            await self._update_source(source, remove)
    else:
        if source not in self._sources_received:  # Not received update from this source yet
            await self._update_source(source, remove=True)
            await self._all_updates_received()  # Verify if discarding node aggregation could be done
        else:
            logging.info(f"Already received update from: {source}, it will be discarded next round")
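The branches of `notify_federation_update` reduce to a small decision table. The hypothetical function below mirrors them on plain sets, with `waiting` standing in for `_round_updates_lock.locked()`: a joining node is deferred while the handler waits for updates, and a leaving node is dropped from the expected set only if its update has not arrived yet. It omits the fresh deque that `_update_source` creates for a joining node, and the returned labels are illustrative only.

```python
def on_neighbor_event(source, remove, expected, received, waiting):
    """Return an action label mirroring notify_federation_update (simplified)."""
    if not remove:
        if waiting:
            return "defer"  # the joining node is counted from the next round
        expected.add(source)
        return "added"
    if source not in received:
        expected.discard(source)
        # If the leaving node was the last one missing, the round is now complete.
        return "removed, complete" if expected <= received else "removed"
    return "keep until next round"


expected, received = {"a", "b"}, {"a"}
print(on_neighbor_event("c", False, expected, received, waiting=True))  # → defer
print(on_neighbor_event("b", True, expected, received, waiting=True))   # → removed, complete
```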

notify_if_all_updates_received() async

Acquires the round-updates lock to start waiting for updates, then notifies the aggregator immediately if all expected updates have already been received.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def notify_if_all_updates_received(self):
    """
    Acquires the round-updates lock to start waiting for updates and notifies
    the aggregator immediately if all expected updates have already been received.
    """
    logging.info("Set notification when all expected updates received")
    await self._round_updates_lock.acquire_async()
    await self._updates_storage_lock.acquire_async()
    all_received = await self._all_updates_received()
    await self._updates_storage_lock.release_async()
    if all_received:
        await self._notify()
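`notify_if_all_updates_received` ends in `_notify`, which uses `_notification_sent_lock` and the `_notification` flag so the aggregator is notified at most once per round; `round_expected_updates` resets the flag. A minimal sketch of that guard with `asyncio.Lock`, assuming a plain coroutine callback in place of `Aggregator.notify_all_updates_received` (`NotifyOnce` is hypothetical):

```python
import asyncio


class NotifyOnce:
    """Hypothetical sketch: fire the callback at most once until reset()."""

    def __init__(self, callback):
        self._lock = asyncio.Lock()
        self._sent = False
        self._callback = callback

    async def notify(self):
        async with self._lock:
            if self._sent:
                return  # already notified this round
            self._sent = True
        # Like _notify, call out only after releasing the guard lock.
        await self._callback()

    def reset(self):
        # round_expected_updates clears the flag at the start of each round.
        self._sent = False


async def main():
    calls = []

    async def cb():
        calls.append("released")

    n = NotifyOnce(cb)
    await asyncio.gather(n.notify(), n.notify(), n.notify())
    n.reset()
    await n.notify()
    return calls


print(asyncio.run(main()))  # → ['released', 'released']
```

Calling the aggregator outside the lock keeps a slow aggregation step from blocking other tasks that only need to check the flag.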

round_expected_updates(federation_nodes) async

Sets the expected nodes for the current training round and updates storage.

Parameters:

federation_nodes (set, required): Nodes expected to send updates this round.
Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def round_expected_updates(self, federation_nodes: set):
    """
    Sets the expected nodes for the current training round and updates storage.

    Args:
        federation_nodes (set): Nodes expected to send updates this round.
    """
    await self._update_federation_lock.acquire_async()
    await self._updates_storage_lock.acquire_async()
    self._sources_expected = federation_nodes.copy()
    self._sources_received.clear()

    # Initialize new nodes
    for fn in federation_nodes:
        if fn not in self.us:
            self.us[fn] = deque(maxlen=self._buffersize)

    # Clear removed nodes
    removed_nodes = [node for node in self._updates_storage.keys() if node not in federation_nodes]
    for rn in removed_nodes:
        del self._updates_storage[rn]

    await self._updates_storage_lock.release_async()
    await self._update_federation_lock.release_async()

    # Release the round-updates lock if a previous wait left it held
    if self._round_updates_lock.locked():
        await self._round_updates_lock.release_async()

    self._notification = False
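The storage bookkeeping in `round_expected_updates` can be sketched as a pure function over a dict of deques: nodes new to the federation get an empty bounded deque, nodes that left are deleted, and existing histories are kept. `reconcile` is a hypothetical name; the locks and the resetting of `_sources_received` and `_notification` are omitted.

```python
from collections import deque


def reconcile(storage: dict, federation_nodes: set, buffersize: int) -> None:
    """Add deques for new nodes and drop nodes that left, keeping existing histories."""
    for node in federation_nodes:
        if node not in storage:
            storage[node] = deque(maxlen=buffersize)
    # Materialize the list first: deleting while iterating a dict raises RuntimeError.
    for node in [n for n in storage if n not in federation_nodes]:
        del storage[node]


storage = {"a": deque([1], maxlen=3), "b": deque([1], maxlen=3)}
reconcile(storage, {"a", "c"}, buffersize=3)
print(sorted(storage), list(storage["a"]))  # → ['a', 'c'] [1]
```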

stop_notifying_updates() async

Stops waiting for updates and releases the round-updates lock if it is held.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def stop_notifying_updates(self):
    """
    Stops waiting for updates and releases the round-updates lock if it is held.
    """
    if self._round_updates_lock.locked():
        logging.info("Stop notification updates")
        await self._round_updates_lock.release_async()
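`stop_notifying_updates` releases the round-updates lock only after checking `locked()`. The sketch below shows the same check-then-release pattern with a plain `asyncio.Lock`, whose `release()` raises `RuntimeError` on an unlocked lock; whether Nebula's `Locker` behaves the same way is an assumption not confirmed by this page, and `release_if_held` is a hypothetical helper.

```python
import asyncio


def release_if_held(lock: asyncio.Lock) -> bool:
    """Release lock only when held; asyncio.Lock.release() raises RuntimeError otherwise."""
    if lock.locked():
        lock.release()
        return True
    return False


async def main():
    lock = asyncio.Lock()
    await lock.acquire()
    first = release_if_held(lock)   # lock was held: released
    second = release_if_held(lock)  # already free: no-op instead of RuntimeError
    return first, second


print(asyncio.run(main()))  # → (True, False)
```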

storage_update(updt_received_event) async

Stores a received update if it comes from an expected source.

Parameters:

updt_received_event (UpdateReceivedEvent, required): The event containing the update.
Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
async def storage_update(self, updt_received_event: UpdateReceivedEvent):
    """
    Stores a received update if it comes from an expected source.

    Args:
        updt_received_event (UpdateReceivedEvent): The event containing the update.
    """
    time_received = time.time()
    (model, weight, source, round, _) = await updt_received_event.get_event_data()

    if source in self._sources_expected:
        updt = Update(model, weight, source, round, time_received)
        await self._updates_storage_lock.acquire_async()
        if updt in self.us[source]:
            logging.info(f"Discard | Already received update from source: {source} for round: {round}")
        else:
            self.us[source].append(updt)
            logging.info(
                f"Storage Update | source={source} | round={round} | weight={weight} | federation nodes: {self._sources_expected}"
            )

            self._sources_received.add(source)
            updates_left = self._sources_expected.difference(self._sources_received)
            logging.info(
                f"Updates received ({len(self._sources_received)}/{len(self._sources_expected)}) | Missing nodes: {updates_left}"
            )
            if self._round_updates_lock.locked() and not updates_left:
                all_rec = await self._all_updates_received()
                if all_rec:
                    await self._notify()
        await self._updates_storage_lock.release_async()
    else:
        if source not in self._sources_received:
            logging.info(f"Discard update | source: {source} not in expected updates for this Round")
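A simplified sketch of the acceptance logic in `storage_update`, assuming round numbers in place of `Update` objects and a `waiting` flag in place of `_round_updates_lock.locked()`: updates from unexpected sources and repeats for an already-stored round are discarded; otherwise the source is marked as received and the function reports whether the round is now complete for a waiting caller. `on_update` is hypothetical; the documented method awaits `_all_updates_received()` and `_notify()` instead of returning a flag.

```python
from collections import deque


def on_update(storage, source, round_no, expected, received, waiting):
    """Simplified storage_update; returns True when the aggregator should be notified."""
    if source not in expected:
        return False  # not expected this round: discarded
    history = storage[source]
    if round_no in history:
        return False  # an update for this round is already stored: discarded
    history.append(round_no)
    received.add(source)
    missing = expected - received
    return waiting and not missing


storage = {"a": deque(maxlen=4), "b": deque(maxlen=4)}
expected, received = {"a", "b"}, set()
print(on_update(storage, "x", 1, expected, received, True))  # → False (unexpected source)
print(on_update(storage, "a", 1, expected, received, True))  # → False (b still missing)
print(on_update(storage, "a", 1, expected, received, True))  # → False (duplicate round)
print(on_update(storage, "b", 1, expected, received, True))  # → True (round complete)
```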

Update

Represents a model update received from a node in a specific training round.

Attributes:

model (object): The model object or weights received.
weight (float): The weight or importance of the update.
source (str): Identifier of the node that sent the update.
round (int): Training round this update belongs to.
time_received (float): Timestamp when the update was received.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
class Update:
    """
    Represents a model update received from a node in a specific training round.

    Attributes:
        model (object): The model object or weights received.
        weight (float): The weight or importance of the update.
        source (str): Identifier of the node that sent the update.
        round (int): Training round this update belongs to.
        time_received (float): Timestamp when the update was received.
    """

    def __init__(self, model, weight, source, round, time_received):
        self.model = model
        self.weight = weight
        self.source = source
        self.round = round
        self.time_received = time_received

    def __eq__(self, other):
        """
        Checks if two updates belong to the same round.
        """
        if not isinstance(other, Update):
            return NotImplemented  # let Python fall back instead of raising AttributeError
        return self.round == other.round

__eq__(other)

Checks if two updates belong to the same round.

Source code in nebula/core/aggregation/updatehandlers/cflupdatehandler.py
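def __eq__(self, other):
    """
    Checks if two updates belong to the same round.
    """
    if not isinstance(other, Update):
        return NotImplemented  # let Python fall back instead of raising AttributeError
    return self.round == other.round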
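Because `Update.__eq__` compares only `round`, two updates from different sources in the same round are equal, and `updt in self.us[source]` in `storage_update` therefore rejects any second update for a round already stored for that source. The sketch below uses a trimmed copy of the class that returns `NotImplemented` for non-`Update` operands, so comparing with another type yields `False` rather than raising `AttributeError`. It also shows a standard Python consequence: a class that defines `__eq__` without `__hash__` becomes unhashable.

```python
from collections import deque


class Update:
    """Trimmed copy of the documented class (model, weight, time fields omitted)."""

    def __init__(self, source, round):
        self.source = source
        self.round = round

    def __eq__(self, other):
        if not isinstance(other, Update):
            return NotImplemented  # fall back to Python's default comparison
        return self.round == other.round


a1 = Update("a", 1)
b1 = Update("b", 1)
print(a1 == b1)         # → True (same round, source ignored)
print(a1 == "round 1")  # → False
print(Update.__hash__)  # → None (overriding __eq__ disables hashing)
print(a1 in deque([Update("a", 0), Update("a", 1)]))  # → True
```

Since `Update.__hash__` is `None`, instances cannot be placed in a `set` or used as `dict` keys; the handler only stores them in deques, where membership relies on `__eq__` alone.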