Add DDP support to hivemind.optim #475
base: master
Conversation
class DDPOptimizer(Optimizer):
    _DDP_LEADER_RANK = 0
    _BROADCAST_BUFFER_SIZE = 250 * 1024 ** 2
Newer PyTorch seems to have finally implemented broadcast_coalesced in distributed; we could use it directly (https://pytorch.org/docs/stable/_modules/torch/nn/parallel/comm.html#broadcast_coalesced) as long as we bump the minimal PyTorch version. What do you think?
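For reference, a minimal sketch of what the coalesced broadcast could look like between the leader and follower ranks. It relies on torch.distributed._broadcast_coalesced, the private helper that DistributedDataParallel itself uses; since that signature is not public and may change between versions, a per-tensor dist.broadcast fallback is included. The helper name _broadcast_state_from_leader is made up for illustration.

    from typing import List

    import torch
    import torch.distributed as dist

    _DDP_LEADER_RANK = 0
    _BROADCAST_BUFFER_SIZE = 250 * 1024 ** 2  # same bucket size as in this PR


    def _broadcast_state_from_leader(tensors: List[torch.Tensor], group=None):
        """Broadcast a list of tensors from the leader rank to all other DDP ranks."""
        group = group if group is not None else dist.group.WORLD
        if hasattr(dist, "_broadcast_coalesced"):
            # private API: coalesces tensors into ~_BROADCAST_BUFFER_SIZE-byte buckets
            # before broadcasting them from the leader rank
            dist._broadcast_coalesced(group, tensors, _BROADCAST_BUFFER_SIZE, _DDP_LEADER_RANK)
        else:
            # public fallback: one broadcast per tensor, no coalescing
            for tensor in tensors:
                dist.broadcast(tensor, src=_DDP_LEADER_RANK, group=group)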
        return torch.distributed.is_initialized()

    @staticmethod
    def is_ddp_leader():
nit: would recommend reusing the same terminology as elsewhere, e.g. inside DistributedDataParallel. For instance, DDP uses
- leader rank -> authoritative rank
- is_ddp_enabled -> _initialized
Codecov Report

@@            Coverage Diff             @@
##           master     #475      +/-   ##
==========================================
- Coverage   83.45%   82.53%   -0.93%
==========================================
  Files          81       82       +1
  Lines        8083     8175      +92
==========================================
+ Hits         6746     6747       +1
- Misses       1337     1428      +91
        return self.is_ddp_leader() and super().is_alive()

    def _compute_state_version(self) -> int:
        """Return a non-decreasing integer that goes up whenever model params and/or buffers were updated"""
This function is meant as a workaround to catch the moment when the optimizer has updated parameters (loaded state from peers, applied an optimizer step, averaged params).
All changes to state are currently handled in StateAverager.
Maybe we can implement StateAverager.local_version that gets incremented every time StateAverager loads, averages, or applies an optimizer update to the state.
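One possible shape for that counter, sketched as a standalone helper. The name LocalVersionTracker and the exact hook points inside StateAverager are made up here; only the idea of a monotonic local_version comes from the comment above.

    class LocalVersionTracker:
        """Hypothetical helper for StateAverager: a monotonic counter of local state updates,
        so DDPOptimizer can cheaply detect that parameters changed and need a re-broadcast."""

        def __init__(self):
            self._local_version = 0

        @property
        def local_version(self) -> int:
            """Non-decreasing; bumped whenever params/buffers are loaded, averaged, or stepped."""
            return self._local_version

        def mark_updated(self) -> None:
            # call after load_state_from_peers, averaging rounds, and optimizer steps
            self._local_version += 1

DDPOptimizer could then remember the last local_version it broadcast and re-sync only when the value has grown.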
hivemind/optim/ddp.py (outdated)

        if self.is_ddp_leader():
            super().load_state_from_peers(**kwargs)
            self._sync_among_ddp_ranks()
We should not synchronize here: non-master ranks cannot call this, so we would deadlock.
We should only sync in step, and after the step check IF the master updated/loaded/averaged state, and only then broadcast.
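A sketch of that flow, reusing the method names from this PR (is_ddp_leader, _compute_state_version, _sync_among_ddp_ranks); the _should_sync helper is hypothetical and only illustrates how the leader could tell followers whether a broadcast is needed.

    import torch
    import torch.distributed as dist
    from hivemind import Optimizer


    class DDPOptimizer(Optimizer):  # sketch only
        _DDP_LEADER_RANK = 0

        def step(self, *args, **kwargs):
            if self.is_ddp_leader():
                version_before = self._compute_state_version()
                loss = super().step(*args, **kwargs)
                state_changed = self._compute_state_version() != version_before
            else:
                loss, state_changed = None, False

            # every rank reaches this point, so the collectives below cannot deadlock
            if self._should_sync(state_changed):
                self._sync_among_ddp_ranks()
            return loss

        def _should_sync(self, leader_state_changed: bool) -> bool:
            """Let the leader tell all ranks whether its state changed during step()."""
            # note: with a NCCL process group this flag tensor would need to live on GPU
            flag = torch.tensor([int(leader_state_changed)])
            dist.broadcast(flag, src=self._DDP_LEADER_RANK)
            return bool(flag.item())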
hivemind/optim/ddp.py (outdated)

        if self.is_ddp_leader():
            super().load_state_dict(state_dict)
            self._sync_among_ddp_ranks()
We should not synchronize here either: non-master ranks cannot call this and we will deadlock; see load_state_from_peers above.
    def shutdown(self):
        if self.is_ddp_leader():
            super().shutdown()
Optional: else raise NotImplementedError or warn?
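For example, the warning variant could look like this (a sketch; logger is the module-level logger created with get_logger(__name__) in this file):

    def shutdown(self):
        if self.is_ddp_leader():
            super().shutdown()
        else:
            # followers own no background workers; warn instead of failing silently
            logger.warning("shutdown() called on a non-leader DDP rank; nothing to shut down")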
hivemind/optim/ddp.py (outdated)

    def is_alive(self) -> bool:
        # On followers, this always returns False since there's nothing to shut down in __del__()
        return self.is_ddp_leader() and super().is_alive()
if leader:
    return is_alive
else:
    raise NotImplementedError?
@@ -131,10 +131,10 @@ def __init__(
     )

     @staticmethod
-    def _check_params(
+    def check_params(
Suggested change:
-    def check_params(
+    def prepare_params(
logger = get_logger(__name__)


class DDPOptimizer(Optimizer):
Note to self: A better way to do it is:
- Don't inherit hivemind.Optimizer
- Make a _create_optimizer() method and forward __init__'s kwargs there
- Make an opt property
- Maybe create __getattr__ that can forward attrs to opt (see the sketch below)
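A rough sketch of that composition-based layout; the leader check, constructor signature, and forwarding logic are illustrative only, not a final API.

    import torch.distributed as dist

    import hivemind


    class DDPOptimizer:
        """Sketch: wrap hivemind.Optimizer instead of inheriting from it."""

        _DDP_LEADER_RANK = 0

        def __init__(self, **kwargs):
            # only the leader rank talks to the swarm; followers keep no optimizer at all
            self._opt = self._create_optimizer(**kwargs) if self.is_ddp_leader() else None

        @classmethod
        def is_ddp_leader(cls) -> bool:
            return not dist.is_initialized() or dist.get_rank() == cls._DDP_LEADER_RANK

        def _create_optimizer(self, **kwargs) -> hivemind.Optimizer:
            return hivemind.Optimizer(**kwargs)

        @property
        def opt(self) -> hivemind.Optimizer:
            return self._opt

        def __getattr__(self, name):
            # called only for attributes not found on DDPOptimizer itself;
            # forward them to the wrapped optimizer on the leader rank
            opt = self.__dict__.get("_opt")
            if opt is None:
                raise AttributeError(name)
            return getattr(opt, name)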
Status: This PR is an early draft intended to validate the design of hivemind.DDPOptimizer. I haven't run the code even once yet.

Co-authored-by: @justheuristic