Commit 43107207 authored by Edward Schwartz, committed by TensorFlower Gardener

Create new PjRt GPU client with remote devices after coordination service agent is available

This change does not modify the previous PjRt GPU client. Test environments that use multiple threads to simulate multiple workers are supported.

For the PjRt GPU MultiWorkerMirroredStrategy (MWMS) case, where there are multiple workers each with one or more GPUs, the symptom of a client without this fix is an error message like the one below. The number at the end is the ID of the first remote device: e.g. 8 when there are eight local GPUs with IDs 0..7, or 1 when there is one local GPU with ID 0.
`INVALID_ARGUMENT: No matching device found for device_id 8`
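
The relationship between local GPU count and the failing device ID can be illustrated with a minimal Python sketch. It assumes global device IDs are assigned contiguously, worker by worker, which matches the two examples above but is not confirmed here as the general scheme. The helper names `global_device_ids` and `lookup` are hypothetical and are not part of TensorFlow or PjRt.

```python
def global_device_ids(gpus_per_worker):
    """Assign contiguous global IDs to GPUs, worker by worker (assumed scheme)."""
    ids, next_id = [], 0
    for count in gpus_per_worker:
        ids.append(list(range(next_id, next_id + count)))
        next_id += count
    return ids


def lookup(known_ids, device_id):
    """Mimic a client that only resolves the device IDs it was created with."""
    if device_id not in known_ids:
        raise LookupError(
            f"INVALID_ARGUMENT: No matching device found for device_id {device_id}"
        )
    return device_id


# Two workers with eight GPUs each: worker 0 owns IDs 0..7, worker 1 owns 8..15.
ids = global_device_ids([8, 8])
local_only = ids[0]                 # a client created with local devices only
first_remote = ids[1][0]            # the first remote device ID

try:
    lookup(local_only, first_remote)
except LookupError as err:
    message = str(err)
```

In this sketch, a client that also knows the remote IDs (`ids[0] + ids[1]`) resolves the same lookup without error, which is the situation the new client is meant to reach once the coordination service agent is available.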

Note that while the primary purpose of `BaseGPUDeviceFactory::CreateDevices` is to do one-time initialization, it is often called multiple times. In the typical production MWMS case, it is called both when creating a TF Context and when creating gRPC servers to enable collectives. In the unit test added by this CL (which uses two GPUs), it is first called when a TF Context is created while starting up the test environment (with both GPUs). Two worker TF processes are then created, each of which calls it twice during MWMS startup (each with the one GPU assigned to that process). Other test cases have different patterns.
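
The call pattern described for the unit test can be written out as a small Python sketch. It only enumerates the calls as stated above and does not call any TensorFlow code. The labels `test_env`, `worker_0`, and `worker_1` and the function `unit_test_call_pattern` are hypothetical names chosen for illustration.

```python
def unit_test_call_pattern():
    """List (caller, visible_gpu_count) for each CreateDevices call in the test."""
    # First call: TF Context creation at test environment startup, both GPUs.
    calls = [("test_env", 2)]
    # Then each of the two workers calls it twice during MWMS startup,
    # each time with the single GPU assigned to that worker.
    for worker in ("worker_0", "worker_1"):
        calls += [(worker, 1)] * 2
    return calls


calls = unit_test_call_pattern()
total_calls = len(calls)
```

Under these assumptions the function is invoked five times in the unit test, with a different visible GPU count on the first call than on the later ones, so any logic in it cannot assume a single call with a fixed device set.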

PiperOrigin-RevId: 611621382
parent 590adf88
Showing changes with 737 additions and 54 deletions