
Conversation

Contributor

@kai-tub kai-tub commented Mar 25, 2020

Fixes #843

Description:
Provides a wrapper for PyTorch's ThroughputBenchmark.

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

TODOs:

  • Wrong representation in handlers.html, shows: torch.nn.modules.module.Module instead of Union[torch.nn.Module, torch.jit.ScriptModule]
  • Discuss how the devices should be used.

How should we work with devices? Normally, everything should run on the CPU, as it would be quite unusual to use GPUs for inference. Should we make sure everything is moved to the CPU (model and data)?
If we allow different devices, then I think a prepare_batch function is necessary, so the user can move the data manually to the correct device (a minimal sketch is below).
What are your thoughts?
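One possible shape for such a hook, shown only as a minimal sketch: it assumes batches are (x, y) pairs and reuses ignite.utils.convert_tensor; the name prepare_batch is the one floated above, not an existing API.

import torch
from ignite.utils import convert_tensor


def prepare_batch(batch, device=None, non_blocking=False):
    """Illustrative hook: move the input part of an (x, y) batch to `device`."""
    x, _ = batch
    return convert_tensor(x, device=device, non_blocking=non_blocking)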

Contributor

@justusschock justusschock left a comment

I left a few comments for discussion :)

from torch.utils.throughput_benchmark import ExecutionStats # for linting


# TODO: Discuss device implications
Contributor

What do you mean by device implications?

Contributor Author

Yeah, sorry for that cryptic TODO. I was referring to my comment in my previous post.
As in: how should we handle the different devices? Should we support them? Should we assume everything always runs on the CPU? Should we move everything to the CPU for easier use?

Collaborator

@kai-tub I'm trying to play a bit with the implementation, and unfortunately the device implications and passing the evaluator as an input to attach look very misleading. For example,

from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor
from torchvision.models.resnet import resnet101
from ignite.engine import create_supervised_evaluator

device = "cuda"
model = resnet101().to(device)
dataset = FakeData(num_classes=1000, transform=ToTensor())
dataloader = DataLoader(dataset, batch_size=32, num_workers=10, shuffle=True)
evaluator = create_supervised_evaluator(model, device=device, non_blocking=True)
# evaluator's update_function passes batches to specified device etc

throughput_benchmark = ThroughputBenchmarkWrapper(model)
max_batches = 10
with throughput_benchmark.attach(evaluator, max_batches=max_batches) as evaluator_with_benchmark:
    evaluator_with_benchmark.run(dataloader)

# however here everything crashes as model is on GPU but dataloader's batches are on CPU

Seems like throughput_benchmark should behave like an Engine ?

device = "cuda"
model = resnet101().to(device)
dataset = FakeData(num_classes=1000, transform=ToTensor())
dataloader = DataLoader(dataset, batch_size=32, num_workers=10, shuffle=True)

throughput_benchmark = create_throughput_benchmark(model, device=device, non_blocking=True, prepare_batch=None)
max_batches = 10
stats = throughput_benchmark.run(dataloader, max_batches=max_batches)

What do you think ?

Contributor Author

Yes, I like the idea of using it more like an engine. In my mind, the user will not run the benchmark after every iteration/epoch, as the throughput shouldn't change when the weights change.
It is more like a separate step of the pipeline, where the performance of different models is quickly evaluated in isolation, maybe with different types of input data (see the hypothetical usage sketch below).
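A hypothetical usage sketch of the API proposed above; create_throughput_benchmark, its signature, and the returned stats are all still under discussion in this thread, so this is an illustration rather than the actual implementation.

from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor
from torchvision.models.resnet import resnet18, resnet50

# Small fake dataset so the benchmark itself dominates the runtime.
dataloader = DataLoader(FakeData(size=64, num_classes=1000, transform=ToTensor()), batch_size=8)

for name, model in [("resnet18", resnet18()), ("resnet50", resnet50())]:
    throughput_benchmark = create_throughput_benchmark(model, device="cpu")
    stats = throughput_benchmark.run(dataloader, max_batches=2)
    print(name, stats)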

benchmark.

Args:
model (Union[torch.nn.Module, torch.jit.ScriptModule]): model which will
Contributor

Just torch.nn.Module should be sufficient as a type, I guess. I think a ScriptModule is a subclass of the usual Module, IIRC.
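This can be confirmed quickly in an interpreter (a small side check, not part of the PR):

import torch

# ScriptModule is a subclass of nn.Module, so annotating the parameter
# as torch.nn.Module also covers scripted models.
assert issubclass(torch.jit.ScriptModule, torch.nn.Module)

scripted = torch.jit.script(torch.nn.Linear(4, 2))
assert isinstance(scripted, torch.nn.Module)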

Contributor Author

Yes, you are correct. I will change it in my next commit. :)

import contextlib
from typing import Callable, Union
from ignite.engine import Events, Engine
from torch.utils import ThroughputBenchmark
Contributor

Can you wrap this in a way that first checks whether the PyTorch version is recent enough? This will probably result in a more understandable error message than a plain import error.

Contributor Author

Are you only referring to the from torch.utils import?
Should I simply use a try/except ImportError to get the ThroughputBenchmark, or should I inspect torch.__version__ and raise a different error if it is too low?

In the other contrib files I've seen that the specific module is loaded lazily and a runtime error is raised if it doesn't exist. I could also do that, but then I couldn't annotate the return type of def stats.
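A minimal sketch of what the explicit version check could look like; the 1.2.0 minimum is taken from the error message used later in this thread, and the exact layout is an assumption.

from distutils.version import LooseVersion

import torch

# Fail early with a readable message instead of a bare ImportError.
if LooseVersion(torch.__version__) < LooseVersion("1.2.0"):
    raise RuntimeError(
        "ThroughputBenchmark requires torch>=1.2.0, but found torch=={}".format(torch.__version__)
    )

from torch.utils.throughput_benchmark import ThroughputBenchmark  # noqa: E402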

Contributor Author

I've created a typed stub for ExecutionStats so that I don't have to import it manually.
The check if I can access torch.utils.ThroughputBenchmark is now done lazily and is more similar to the other approaches.

self._detach(engine)

@property
def stats(self) -> ExecutionStats:
Collaborator

Maybe we can promote self._stats to a public attribute without the property wrapping?
That way we can avoid redefining ExecutionStats and the typing issue.

Contributor Author

In that case, I personally would still favor writing it like:
self.stats: Optional[ExecutionStats]
I am a big fan of type hinting and abusing auto-complete as much as possible ;)

But I agree that the current approach is a sub-par solution.
In the end, either the class has to be imported inside a try block, a stub has to be created, or there is no typing support at all.

But if you like it more that way, I will rewrite the code and update it to link to the specific class.

Collaborator

I am a big fan of type hinting and abusing auto-complete as much as possible ;)

@kai-tub that's cool! So, can we find a solution that avoids the stub ExecutionStats, keeps our try/except in __init__ to import PyTorch's ThroughputBenchmark, and defines the type of self.stats as you propose?

Contributor Author

See my new commit. It looks better, but I won't guarantee that all auto-complete tools will work with this annotation style. ;)

Contributor Author

Totally missed that Python 3.5 is still supported. Yeah, then the variable annotation won't work, so we are back where we started.
Instead of writing a typing stub, we could try to import this specific module and catch the ImportError if it is raised. Since the import is never used, it shouldn't cause any problems, but it would be a confusing line to read.

Contributor Author

Sorry for not committing code, I am currently changing everything to use the engine design.
What I am currently thinking of (combining both approaches) is something like this:

from typing import Callable, Optional
import contextlib

import torch

from ignite.engine import Events, Engine

try:
    from torch.utils.throughput_benchmark import ThroughputBenchmark
    from torch.utils.throughput_benchmark import ExecutionStats  # for typing
except ImportError:
    raise RuntimeError(...)

class ThroughputBenchmarkWrapper:

    @property
    def stats(self) -> ExecutionStats:
        ...

The reason I didn't write it like that to start with was that I wanted the user to be able to import the code without raising an exception. To provide typing support without raising the exception at import time and without using a stub, I proposed the weird try/except that you just wrote. :)

Collaborator

Sorry for not committing code, I am currently changing everything to use the engine design.

No problem, thanks a lot for working on that!

The issue with

from typing import Callable, Optional
import contextlib

import torch

from ignite.engine import Events, Engine

try:
    from torch.utils.throughput_benchmark import ThroughputBenchmark
    from torch.utils.throughput_benchmark import ExecutionStats  # for typing
except ImportError:
    raise RuntimeError(...)


class ThroughputBenchmarkWrapper:
    ...

as you say, users with PyTorch 1.1.0 won't be able to use the ignite.contrib module at all, as it imports ThroughputBenchmarkWrapper. So this solution is not possible.
That's why the property seems to be unfeasible with all our limitations...

Anyway, currently the contrib module is not using typing at all and even the core module won't pass mypy tests. IMO, the solution is either to use typing in a comment, to remove the typing (which is :( ), or something else, if possible, that satisfies all our limitations...
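One standard shape for the "typing in a comment" option, sketched only as an illustration; the TYPE_CHECKING guard and the attribute name are assumptions, not part of the PR.

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, so older torch versions are unaffected at runtime.
    from torch.utils.throughput_benchmark import ExecutionStats


class ThroughputBenchmarkWrapper:
    def __init__(self):
        # Python 3.5-compatible type comment instead of a variable annotation:
        self.stats = None  # type: Optional[ExecutionStats]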

Contributor Author

Ahhh. I totally missed that, because of __init__.py, the exception will be raised as soon as ignite.contrib is loaded... OK, then I agree and will use the typing in a comment. :)

I've pushed the redesigned engine version, without any tests or documentation, just so it is easier for us to discuss the code. This version can be called exactly as you previously proposed, but the code feels... hacky. I've inherited from Engine to be able to change how the run function is executed and to be able to initialize the other parameters.

But I am unsure how to rewrite the code.

Collaborator

@vfdev-5 vfdev-5 Mar 26, 2020

Yes, I understand your feeling. I feel the same every time I inherit directly from Engine and set up some internal handlers. Maybe we do not need to create a class like BenchmarkEngine(Engine) at all, and can just have create_throughput_benchmark return a specifically configured instance of Engine?

PS: I haven't tried to do that myself, so I do not see all the implications of it...

Contributor Author

Hmm.
So here is an alternative idea:
I've used a singleton wrapper to create the benchmark only once, instead of having to wrap the complete function again to keep the state.

class SingletonDecorator:
    def __init__(self, class_):
        self.class_ = class_
        self.instance = None

    def __call__(self, *args, **kwargs):
        if self.instance is None:
            self.instance = self.class_(*args, **kwargs)
        return self.instance


def create_throughput_benchmark(
    model: torch.nn.Module,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_input: Callable = _prepare_input,
    max_batches: int = 10,
    num_calling_threads: int = 1,
    num_warmup_iters: int = 10,
    num_iters: int = 100,
) -> Engine:
    try:
        from torch.utils.throughput_benchmark import ThroughputBenchmark
    except ImportError:
        raise RuntimeError("This method requires at least pytorch version 1.2.0")

    ThroughputBenchmark = SingletonDecorator(ThroughputBenchmark)

    def _run(engine: Engine, batch: Sequence[torch.Tensor]) -> None:
        benchmark = ThroughputBenchmark(model)
        # In this loop nothing happens besides the loading of the input data
        model.eval()  # Should the user expect the automatic setting to eval?
        input_data = prepare_input(batch, device=device, non_blocking=non_blocking)
        benchmark.add_input(input_data)
        if engine.state.iteration == max_batches:
            engine.terminate()
            engine.results = benchmark.benchmark(
                num_calling_threads=num_calling_threads, num_warmup_iters=num_warmup_iters, num_iters=num_iters
            )

    engine = Engine(_run)

    # see PR https://github.com/pytorch/ignite/pull/835
    if device is not None:

        @engine.on(Events.STARTED)
        def move_device(engine):
            model.to(device)

    return engine

Now:

  • The user has to specify max_batches, num_calling_threads, etc. when creating the engine and not when calling the run function, as run is now a vanilla run function.
  • Kept the internal handlers to a minimum and moved the logic into the _run function, which is called by the engine.
  • Problem: I am not sure how the user should access the results... What would be the easiest way? I've played around with the Engine a bit, but I cannot change the internal state representation. I've also thought about somehow integrating it into a Metric, so it could be saved there, but that seems to be overkill.

I am sorry, but I need some help understanding how the result could be integrated into the current engine.

PS: Do you like it better, when I paste these "pseudo-examples" here or would you prefer them on my branch? I am hesitating because I don't want to unnecessarily spam the CI pipeline for some random thoughts.

Kai Clasen and others added 2 commits March 25, 2020 23:18
  • Promoted stats property to attribute execution_stats. ExecutionStats class will now be lazily loaded and type annotated. May depend on the specific implementation, if this annotation is supported by the IDE.
  • Also reordered the libraries as desired
Collaborator

vfdev-5 commented Mar 26, 2020

@kai-tub let's discuss here.

PS: Do you like it better, when I paste these "pseudo-examples" here or would you prefer them on my branch? I am hesitating because I don't want to unnecessarily spam the CI pipeline for some random thoughts.

Yes, we can exchange the code here too and avoid useless CI builds.

Problem: I am not sure how the user should access the results... What would be the easiest way? I've played around with the Engine a bit, but I cannot change the internal state representation. I've also thought about somehow integrating it into a Metric, so it could be saved there, but that seems to be overkill.

We can put it into engine.state.stats ?

What is the purpose of SingletonDecorator ? We can define benchmark outside of _run:

def create_throughput_benchmark(
    model: torch.nn.Module,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_input: Callable = _prepare_input,
    max_batches: int = 10,
    num_calling_threads: int = 1,
    num_warmup_iters: int = 10,
    num_iters: int = 100,
) -> Engine:

    benchmark = ThroughputBenchmark(model)
    model.eval()  # Should the user expect the automatic setting to eval?

    def _run(engine: Engine, batch: Sequence[torch.Tensor]) -> None:
        input_data = prepare_input(batch, device=device, non_blocking=non_blocking)
        benchmark.add_input(input_data)
        if engine.state.iteration == max_batches:
            engine.terminate()
            engine.results = benchmark.benchmark(
                num_calling_threads=num_calling_threads, num_warmup_iters=num_warmup_iters, num_iters=num_iters
            )

?

Contributor Author

kai-tub commented Mar 26, 2020

We can put it into engine.state.stats ?

Ok, I don't know what I've tested. I thought I tried to write to engine.state and got an error.
Yes, I think that should be a good solution. :)

What is the purpose of SingletonDecorator ?

I wasn't sure if the to() would've changed the model in the benchmark, but it seems like it works.
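A quick side check of that behaviour (not part of the PR): nn.Module.to() modifies the module in place and returns self, unlike Tensor.to().

import torch

model = torch.nn.Linear(4, 2)
moved = model.to(torch.device("cpu"))

# nn.Module.to() moves parameters in place and returns the same object,
# whereas Tensor.to() returns a new tensor.
assert moved is model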

from typing import Callable, Optional, Sequence, Union

import torch

from ignite.engine import Events, Engine
from ignite.utils import convert_tensor


def _prepare_input(
    batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
):
    """
    Prepare batch for adding to benchmark: pass to a device with options.
    """
    x, _ = batch
    return convert_tensor(x, device=device, non_blocking=non_blocking)


def create_throughput_benchmark(
    model: torch.nn.Module,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_input: Callable = _prepare_input,
    max_batches: int = 10,
    num_calling_threads: int = 1,
    num_warmup_iters: int = 10,
    num_iters: int = 100,
) -> Engine:
    try:
        from torch.utils.throughput_benchmark import ThroughputBenchmark
    except ImportError:
        raise RuntimeError("This method requires at least pytorch version 1.2.0")

    benchmark = ThroughputBenchmark(model)
    model.eval()  # Should the user expect the automatic setting to eval?

    def _run(engine: Engine, batch: Sequence[torch.Tensor]) -> None:
        input_data = prepare_input(batch, device=device, non_blocking=non_blocking)
        benchmark.add_input(input_data)
        if engine.state.iteration == max_batches:
            engine.terminate()
            engine.state.stats = benchmark.benchmark(
                num_calling_threads=num_calling_threads, num_warmup_iters=num_warmup_iters, num_iters=num_iters
            )

    engine = Engine(_run)
    engine.state.stats = None

    # see PR https://github.com/pytorch/ignite/pull/835
    if device is not None:

        @engine.on(Events.STARTED)
        def move_device(engine):
            print("Moving device")
            model.to(device)

    return engine

OK, if you are happy with the current "pseudo-implementation", I would start to write tests and docs after these questions are settled:

  • Should model.eval() be called? Or should this be an argument? I see little reason why it should be in train() mode, but I don't want to be too restrictive. What is your opinion?
  • In this implementation stats has no typing support, and a linter will probably complain that engine.state.stats doesn't exist.
  • Is the model.to() move fine?

Collaborator

vfdev-5 commented Mar 26, 2020

Just some remarks:

    engine = Engine(_run)
    # engine.state.stats = None  # it should be set on STARTED, as state is None here

    @engine.on(Events.STARTED)
    def move_device(engine):
        engine.state.stats = None
        # see PR https://github.com/pytorch/ignite/pull/835
        if device is not None:
            print("Moving device")
            model.to(device)

If you would like to test it on 1 GPU, you can use google colab ?

Should model.eval() be called? Or should this be an argument? I see little reason why it should be in train() mode, but I don't want to be too restrictive. What is your opinion?

We can check both options if it makes sense. Then, probably, we can remove it.

I'll try to check this with DP and DDP on multiple GPUs and comment here. Hopefully, we can also move forward on your other PR.

Contributor Author

kai-tub commented Mar 26, 2020

If you would like to test it on 1 GPU, you can use google colab ?

This is what I've just done. :)
https://colab.research.google.com/drive/1FQByhKm9zgoMVlUDeY9ahBr_oBApx0Fq

We can check both options if it makes sense. Then, probably, we can remove it.

I was thinking that somebody could have an uncommon architecture that dynamically uses different paths for the dataflow, depending on the mode. (Which would weird, but just as a thought)
I have no experience in writing and maintaining libraries, so I will just go with what you recommend. :)
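A contrived sketch (purely illustrative, not from the PR) of the kind of architecture meant, where train/eval mode changes the forward path and therefore the measured throughput:

import torch


class ModeDependentNet(torch.nn.Module):
    """Contrived example: the forward path changes with train/eval mode."""

    def __init__(self):
        super().__init__()
        self.cheap_head = torch.nn.Linear(16, 10)
        self.expensive_head = torch.nn.Sequential(
            torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10)
        )

    def forward(self, x):
        # self.training is toggled by model.train() / model.eval(),
        # so throughput would differ between the two modes.
        if self.training:
            return self.expensive_head(x)
        return self.cheap_head(x)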

Collaborator

vfdev-5 commented Mar 28, 2020

@kai-tub I tested this module from my side and I have two improvements to add.

  • Add an optional argument to set up a ProgressBar.
  • Once it is set up, we can see that engine.terminate() inside the processing function removes the last update of the progress bar... => terminate + run the benchmark in a separate handler.
  • In order to set up the progress bar correctly, we need to set the engine's epoch length to max_batches. Maybe this could be done better by overriding the run method so that it accepts only a dataloader...
  • Tested on a single GPU and on 2 GPUs with DP (x3 slower than a single GPU).
  • Tested with DDP, and it seems the stats are not synchronized across the participating devices... so TODO.

import torch.distributed as dist

from ignite.contrib.handlers import ProgressBar


def create_throughput_benchmark(
    model: torch.nn.Module,
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_input: Callable = _prepare_input,
    max_batches: int = 10,
    num_calling_threads: int = 1,
    num_warmup_iters: int = 10,
    num_iters: int = 100,
) -> Engine:
    try:
        from torch.utils.throughput_benchmark import ThroughputBenchmark
    except ImportError:
        raise RuntimeError("This method requires at least pytorch version 1.2.0")

    benchmark = ThroughputBenchmark(model)

    def _run(engine: Engine, batch: Sequence[torch.Tensor]) -> None:
        input_data = prepare_input(batch, device=device, non_blocking=non_blocking)
        benchmark.add_input(input_data)

    engine = Engine(_run)
    engine.load_state_dict({"iteration": 0, "epoch_length": max_batches, "max_epochs": 1, "seed": 0})
    
    if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0:
        ProgressBar(desc="ThroughputBenchmark").attach(engine)

    @engine.on(Events.ITERATION_COMPLETED(once=max_batches))
    def start_benchmark(_):
        engine.terminate()
        engine.state.stats = benchmark.benchmark(
            num_calling_threads=num_calling_threads, num_warmup_iters=num_warmup_iters, num_iters=num_iters
        )
        if dist.is_available() and dist.is_initialized():
            # reduce metrics across all devices
            pass

    # see PR https://github.com/pytorch/ignite/pull/835
    @engine.on(Events.STARTED)
    def move_device(engine):
        engine.state.stats = None
        print(type(model))
        if device is not None:
            print("Moving device")
            model.to(device)

    return engine

Contributor Author

kai-tub commented Apr 10, 2020

Tested with DDP, and it seems the stats are not synchronized across the participating devices... so TODO.

Could you give me a hint on how this could be implemented?
After reading up on it, I first thought that the easiest way to implement it was to use dist.reduce() over a new set of tensors parsed from the ExecutionStats class. Something like this:

from numbers import Number
from typing import Dict

import torch


def _parse_execution_stats(exec_stats) -> Dict[str, torch.Tensor]:
    """
    Parse all of the number properties into a dictionary with the key defined by the
    attribute name and the value converted to a Tensor.
    """
    return {
        attribute: torch.tensor(getattr(exec_stats, attribute)) 
        for attribute in dir(exec_stats) 
        if isinstance(getattr(exec_stats, attribute), Number)
    }

The important part is simply creating a tensor, so that I can later use dist.reduce(), but I forgot that the tensor might have to live on the GPU. Then I would need to somehow know the current GPU device used by the process, and that no longer feels like a good solution.
Now I am back to a "trivial" solution, where I would let each process write its values to disk (I don't know how easy this is in the DDP case) under a random name, use dist.barrier() to wait, and then do the calculations on rank 0.

I am unsure how to otherwise communicate between the different processes (especially with different backends).

Collaborator

vfdev-5 commented Apr 10, 2020

@kai-tub I think your approach with dist.reduce or dist.all_reduce can be OK. We can all-reduce the metrics such that every process has the same stats, and then the user decides from which rank to use them...

You can also test this with the "gloo" backend, like we do in our distributed tests on CPU:

ignite/.travis.yml

Lines 55 to 56 in 32275a8

- export WORLD_SIZE=2
- py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python$TRAVIS_PYTHON_VERSION tests -m distributed -vvv

and https://github.com/pytorch/ignite/blob/master/tests/ignite/metrics/test_accumulation.py#L366

PS: pytest-xdist is required to run tests

Contributor Author

kai-tub commented Apr 10, 2020

I just skimmed the code, but if I implemented it similarly to the provided code, with device = local_rank (if this is not what is happening, you can ignore this and I will take a closer look tomorrow), wouldn't it force the user to have a single GPU per rank (assuming no model parallelism)? I would assume this is the general case, but the most confusing part of DDP for me comes from this comment:

setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
rank 2 uses GPUs [4, 5, 6, 7].

in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

Collaborator

vfdev-5 commented Apr 10, 2020

Well, let's assume for instance the DDP "Multi-Process Single-GPU" case, where I think there is nothing special to do with devices: for each process, put the scalars into tensors, all_reduce, and convert back to scalars.

Concerning the comments from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html:

    setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
    rank 2 uses GPUs [4, 5, 6, 7].

IMO, it is a sort of combination of DP and DDP:

          rank1    rank2
Data : [--------|--------]
Model:     DP1      DP2          x4 each

Contributor Author

kai-tub commented Apr 11, 2020

Well, let's assume for instance the DDP "Multi-Process Single-GPU" case, where I think there is nothing special to do with devices: for each process, put the scalars into tensors, all_reduce, and convert back to scalars.

But I think this approach only works for the gloo backend. If the user uses the nccl backend, then the tensor has to live on the GPU, as only GPU tensors are supported. Also, the tensor is not allowed to live on the GPU of a different process. Then I would be required to get the "current" GPU used by the process. Should I simply get the device as is done in this line:
https://github.com/pytorch/ignite/blob/master/tests/ignite/metrics/test_accumulation.py#L388?
Or would this depend on how the user initializes the process group and GPUs?

Sorry for these questions, I've never used DDP before.

Collaborator

vfdev-5 commented Apr 11, 2020

Here is how DDP with nccl works in the "Multi-Process Single-GPU" setup, for example on 2 nodes with 2 GPUs each. We start the code as:

    On the first terminal for node0, run $ python main.py --rank=0 --local-rank=0
    On the second terminal for node0, run $ python main.py --rank=1 --local-rank=1
    On the first terminal for node1, run $ python main.py --rank=2 --local-rank=0
    On the second terminal for node1, run $ python main.py --rank=3 --local-rank=1

or with the launch utility: https://pytorch.org/docs/stable/distributed.html#launch-utility
In the code, we set the device to the local-rank GPU and initialize the process group:

dp_device_ids = [local_rank]
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method=dist_url, rank=rank, world_size=world_size)

so, within each process, device="cuda" corresponds to the right GPU.
Based on https://pytorch.org/tutorials/beginner/aws_distributed_training_tutorial.html

In our case, we have to do something like this:

v = [12, 13, ]
t = torch.tensor(v, device=device)  # device should be the same as in create_throughput_benchmark
dist.all_reduce(t)  # -> t is now reduced across all devices
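Putting the two pieces together, a hedged sketch of how the parsed stats could be reduced across processes; the helper name, the averaging, and the dtype choice are assumptions, not a settled design.

import torch
import torch.distributed as dist


def _all_reduce_stats(stats_dict, device):
    # stats_dict would come from something like _parse_execution_stats above;
    # `device` should match the device used in create_throughput_benchmark,
    # so the tensors are also valid for the nccl backend.
    reduced = {}
    for name, value in stats_dict.items():
        t = value.to(device=device, dtype=torch.float64)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        # Averaging across processes is an assumption; another reduction may be preferable.
        reduced[name] = (t / dist.get_world_size()).item()
    return reduced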

Sorry for these questions, I've never used DDP before.

No problem. It's always good to learn. :)

Collaborator

vfdev-5 commented May 7, 2020

Hey @kai-tub, any updates on this PR?

Contributor Author

kai-tub commented May 14, 2020

Oh, sorry.
I got heavily side-tracked by different projects.
I will try to work on this starting at the end of next week.

Sorry for the long outstanding PR.

Successfully merging this pull request may close these issues.

Provide tiny wrapper over pytorch ThroughputBenchmark