
Add DDP support to hivemind.optim #475

Draft · wants to merge 2 commits into master

Conversation

borzunov (Member) commented May 31, 2022

Status: This PR is an early draft intended to validate the design of hivemind.DDPOptimizer. I haven't run the code even once yet.

Co-authored-by: @justheuristic


class DDPOptimizer(Optimizer):
    _DDP_LEADER_RANK = 0
    _BROADCAST_BUFFER_SIZE = 250 * 1024 ** 2

Member

Newer PyTorch seems to have finally implemented broadcast_coalesced in distributed; we could use it directly (https://pytorch.org/docs/stable/_modules/torch/nn/parallel/comm.html#broadcast_coalesced) as long as we bump the minimal PyTorch version. What do you think?
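
For context, a minimal sketch of what the leader-to-follower sync could look like with the plain per-tensor API (the helper name and signature below are assumptions, not the PR's code); the coalesced variant linked above would fuse many small tensors into fewer, larger messages:

    import torch
    import torch.distributed as dist

    _DDP_LEADER_RANK = 0

    def broadcast_from_leader(tensors):
        """Broadcast the leader's tensors to all other DDP ranks, one tensor at a time."""
        if not dist.is_initialized():
            return  # single-process run, nothing to sync
        for tensor in tensors:
            # Collective call: every rank must reach this line; src is the authoritative rank
            dist.broadcast(tensor, src=_DDP_LEADER_RANK)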

return torch.distributed.is_initialized()

@staticmethod
def is_ddp_leader():

Member

nit: would recommend reusing the same terminology as existing code, e.g. inside DistributedDataParallel.

For instance, DDP uses:

  • leader rank -> authoritative rank
  • is_ddp_enabled -> _initialized

codecov bot commented May 31, 2022

Codecov Report

Merging #475 (83ff269) into master (97deaee) will decrease coverage by 0.92%.
The diff coverage is 5.00%.

@@            Coverage Diff             @@
##           master     #475      +/-   ##
==========================================
- Coverage   83.45%   82.53%   -0.93%     
==========================================
  Files          81       82       +1     
  Lines        8083     8175      +92     
==========================================
+ Hits         6746     6747       +1     
- Misses       1337     1428      +91     
Impacted Files                        Coverage Δ
hivemind/optim/ddp.py                 0.00% <0.00%> (ø)
hivemind/optim/optimizer.py           69.40% <100.00%> (-0.26%) ⬇️
hivemind/optim/state_averager.py      86.09% <100.00%> (ø)
hivemind/optim/progress_tracker.py    97.80% <0.00%> (-1.10%) ⬇️
hivemind/averaging/matchmaking.py     84.52% <0.00%> (+0.59%) ⬆️
hivemind/averaging/averager.py        89.07% <0.00%> (+0.71%) ⬆️
hivemind/utils/asyncio.py             100.00% <0.00%> (+0.86%) ⬆️

return self.is_ddp_leader() and super().is_alive()

def _compute_state_version(self) -> int:
    """Return a non-decreasing integer that goes up whenever model params and/or buffers were updated"""

Member

This function is meant as a workaround to catch the moment when the optimizer has updated its parameters (loaded state from peers, applied an optimizer step, or averaged params).

All changes to state are currently handled in StateAverager. Maybe we can implement StateAverager.local_version that gets incremented every time StateAverager loads, averages, or updates state via the optimizer.
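
A rough sketch of that local_version idea (method names and call sites below are hypothetical, not what the PR implements):

    class TrainingStateAverager:  # hypothetical excerpt; the real class lives in hivemind/optim/state_averager.py
        def __init__(self, *args, **kwargs):
            ...
            self.local_version = 0  # bumped on every local change to params/buffers/optimizer state

        def _bump_version(self):
            self.local_version += 1

        def load_state_from_peers(self, **kwargs):
            ...  # download and apply the latest collaborative state
            self._bump_version()

        def apply_optimizer_step(self):
            ...  # run the wrapped torch optimizer
            self._bump_version()

        def average_state_with_peers(self):
            ...  # all-reduce parameters/buffers with other peers
            self._bump_version()

With such a counter, DDPOptimizer._compute_state_version() could simply return the averager's local_version instead of inspecting the parameters themselves.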

if self.is_ddp_leader():
    super().load_state_from_peers(**kwargs)

self._sync_among_ddp_ranks()

Member

We should not synchronize here: non-master ranks cannot call this, and we will deadlock.

We should only sync in step() -- and after the step, check whether the master updated/loaded/averaged the state, and only then broadcast.
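
A sketch of that step-only sync pattern (is_ddp_leader, _compute_state_version, _sync_among_ddp_ranks, and _DDP_LEADER_RANK are taken from the diff above; broadcasting the "state changed" flag via broadcast_object_list is an assumption):

    import torch.distributed as dist

    class DDPOptimizer(Optimizer):
        _DDP_LEADER_RANK = 0  # as in the diff above

        def step(self, *args, **kwargs):
            loss = None
            if self.is_ddp_leader():
                version_before = self._compute_state_version()
                loss = super().step(*args, **kwargs)
                state_changed = self._compute_state_version() != version_before
            else:
                state_changed = False  # placeholder; the leader's value is authoritative

            # Every rank reaches this collective, so followers learn whether the leader's state changed
            flag = [state_changed]
            dist.broadcast_object_list(flag, src=self._DDP_LEADER_RANK)
            if flag[0]:
                self._sync_among_ddp_ranks()  # broadcast updated params/buffers from the leader
            return loss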

if self.is_ddp_leader():
    super().load_state_dict(state_dict)

self._sync_among_ddp_ranks()

Member

We should not synchronize here: non-master ranks cannot call this and we will deadlock; see load_state_from_peers.


def shutdown(self):
    if self.is_ddp_leader():
        super().shutdown()

justheuristic (Member) commented May 31, 2022

Optional: else raise NotImplementedError or warn?
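
For example, the "warn" variant could look like this (a sketch; the follower-side message is an assumption):

    def shutdown(self):
        if self.is_ddp_leader():
            super().shutdown()
        else:
            # Followers own no background workers, so there is nothing to stop;
            # alternatively, raise NotImplementedError to fail loudly
            logger.warning("shutdown() was called on a non-leader DDP rank; nothing to do")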


def is_alive(self) -> bool:
    # On followers, this always returns False since there's nothing to shut down in __del__()
    return self.is_ddp_leader() and super().is_alive()

Member

if self.is_ddp_leader():
    return super().is_alive()
else:
    raise NotImplementedError  # ...or warn?

@@ -131,10 +131,10 @@ def __init__(
     )

     @staticmethod
-    def _check_params(
+    def check_params(

Member Author

Suggested change:
-    def check_params(
+    def prepare_params(

logger = get_logger(__name__)


class DDPOptimizer(Optimizer):

Member Author

Note to self: A better way to do it (see the sketch below) is:

  • Don't inherit from hivemind.Optimizer
  • Make a _create_optimizer() method and forward __init__'s kwargs there
  • Make an opt property
  • Maybe create __getattr__ that forwards attribute lookups to opt
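
A rough sketch of that composition-based layout (hypothetical, not part of this PR):

    class DDPOptimizer:
        """Wrap hivemind.Optimizer via composition instead of inheritance (sketch only)."""

        def __init__(self, **kwargs):
            self._opt_kwargs = kwargs
            self._opt = self._create_optimizer()

        def _create_optimizer(self) -> Optimizer:
            # Forward __init__'s kwargs to the wrapped hivemind.Optimizer
            return Optimizer(**self._opt_kwargs)

        @property
        def opt(self) -> Optimizer:
            return self._opt

        def __getattr__(self, name):
            # Invoked only for attributes not found on DDPOptimizer itself;
            # forward them to the wrapped optimizer
            return getattr(self._opt, name)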

