Skip to content

Commit 1bb8d23

Browse files
committed
Create DeviceInterface in addStream
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 4dfc022 commit 1bb8d23

File tree

7 files changed

+61
-34
lines changed

7 files changed

+61
-34
lines changed

src/torchcodec/_core/CudaDevice.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19-
bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
20-
return new CudaDevice(device);
21-
});
19+
bool g_cuda = registerDeviceInterface(
20+
torch::kCUDA,
21+
[](const torch::Device& device) { return new CudaDevice(device); });
2222

2323
// We reuse cuda contexts across VideoDecoder instances. This is because
2424
// creating a cuda context is expensive. The cache mechanism is as follows:
@@ -164,7 +164,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
164164
}
165165
} // namespace
166166

167-
CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
167+
CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
168168
if (device_.type() != torch::kCUDA) {
169169
throw std::runtime_error("Unsupported device: " + device_.str());
170170
}

src/torchcodec/_core/CudaDevice.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
class CudaDevice : public DeviceInterface {
1414
public:
15-
CudaDevice(const std::string& device);
15+
CudaDevice(const torch::Device& device);
1616

1717
virtual ~CudaDevice(){};
1818

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
namespace {
1414
std::mutex g_interface_mutex;
15-
std::map<std::string, CreateDeviceInterfaceFn> g_interface_map;
15+
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;
1616

1717
std::string getDeviceType(const std::string& device) {
1818
size_t pos = device.find(':');
@@ -25,7 +25,7 @@ std::string getDeviceType(const std::string& device) {
2525
} // namespace
2626

2727
bool registerDeviceInterface(
28-
const std::string deviceType,
28+
torch::DeviceType deviceType,
2929
CreateDeviceInterfaceFn createInterface) {
3030
std::scoped_lock lock(g_interface_mutex);
3131
TORCH_CHECK(
@@ -36,15 +36,39 @@ bool registerDeviceInterface(
3636
return true;
3737
}
3838

39-
std::unique_ptr<DeviceInterface> createDeviceInterface(
40-
const std::string device) {
39+
torch::Device createTorchDevice(const std::string device) {
4140
// TODO: remove once DeviceInterface for CPU is implemented
4241
if (device == "cpu") {
43-
return nullptr;
42+
return torch::kCPU;
4443
}
4544

4645
std::scoped_lock lock(g_interface_mutex);
4746
std::string deviceType = getDeviceType(device);
47+
TORCH_CHECK(
48+
std::find_if(
49+
g_interface_map.begin(),
50+
g_interface_map.end(),
51+
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>&
52+
arg) {
53+
return device.rfind(
54+
torch::DeviceTypeName(arg.first, /*lower_case*/ true),
55+
0) == 0;
56+
}) != g_interface_map.end(),
57+
"Unsupported device: ",
58+
device);
59+
60+
return torch::Device(device);
61+
}
62+
63+
std::unique_ptr<DeviceInterface> createDeviceInterface(
64+
const torch::Device& device) {
65+
auto deviceType = device.type();
66+
// TODO: remove once DeviceInterface for CPU is implemented
67+
if (deviceType == torch::kCPU) {
68+
return nullptr;
69+
}
70+
71+
std::scoped_lock lock(g_interface_mutex);
4872
TORCH_CHECK(
4973
g_interface_map.find(deviceType) != g_interface_map.end(),
5074
"Unsupported device: ",

src/torchcodec/_core/DeviceInterface.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace facebook::torchcodec {
2626

2727
class DeviceInterface {
2828
public:
29-
DeviceInterface(const std::string& device) : device_(device) {}
29+
DeviceInterface(const torch::Device& device) : device_(device) {}
3030

3131
virtual ~DeviceInterface(){};
3232

@@ -53,13 +53,15 @@ class DeviceInterface {
5353
};
5454

5555
using CreateDeviceInterfaceFn =
56-
std::function<DeviceInterface*(const std::string& device)>;
56+
std::function<DeviceInterface*(const torch::Device& device)>;
5757

5858
bool registerDeviceInterface(
59-
const std::string deviceType,
59+
torch::DeviceType deviceType,
6060
const CreateDeviceInterfaceFn createInterface);
6161

62+
torch::Device createTorchDevice(const std::string device);
63+
6264
std::unique_ptr<DeviceInterface> createDeviceInterface(
63-
const std::string device);
65+
const torch::Device& device);
6466

6567
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ SingleStreamDecoder::SingleStreamDecoder(
9999

100100
SingleStreamDecoder::~SingleStreamDecoder() {
101101
for (auto& [streamIndex, streamInfo] : streamInfos_) {
102-
auto& device = streamInfo.videoStreamOptions.device;
103-
if (device) {
104-
device->releaseContext(streamInfo.codecContext.get());
102+
auto& deviceInterface = streamInfo.deviceInterface;
103+
if (deviceInterface) {
104+
deviceInterface->releaseContext(streamInfo.codecContext.get());
105105
}
106106
}
107107
}
@@ -391,7 +391,7 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
391391
void SingleStreamDecoder::addStream(
392392
int streamIndex,
393393
AVMediaType mediaType,
394-
DeviceInterface* device,
394+
const torch::Device& device,
395395
std::optional<int> ffmpegThreadCount) {
396396
TORCH_CHECK(
397397
activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -419,6 +419,7 @@ void SingleStreamDecoder::addStream(
419419
streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base;
420420
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
421421
streamInfo.avMediaType = mediaType;
422+
streamInfo.deviceInterface = createDeviceInterface(device);
422423

423424
// This should never happen, checking just to be safe.
424425
TORCH_CHECK(
@@ -430,9 +431,10 @@ void SingleStreamDecoder::addStream(
430431
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
431432
// addStream() which is supposed to be generic
432433
if (mediaType == AVMEDIA_TYPE_VIDEO) {
433-
if (device) {
434+
if (streamInfo.deviceInterface) {
434435
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
435-
device->findCodec(streamInfo.stream->codecpar->codec_id)
436+
streamInfo.deviceInterface
437+
->findCodec(streamInfo.stream->codecpar->codec_id)
436438
.value_or(avCodec));
437439
}
438440
}
@@ -450,8 +452,8 @@ void SingleStreamDecoder::addStream(
450452

451453
// TODO_CODE_QUALITY same as above.
452454
if (mediaType == AVMEDIA_TYPE_VIDEO) {
453-
if (device) {
454-
device->initializeContext(codecContext);
455+
if (streamInfo.deviceInterface) {
456+
streamInfo.deviceInterface->initializeContext(codecContext);
455457
}
456458
}
457459

@@ -481,7 +483,7 @@ void SingleStreamDecoder::addVideoStream(
481483
addStream(
482484
streamIndex,
483485
AVMEDIA_TYPE_VIDEO,
484-
videoStreamOptions.device.get(),
486+
videoStreamOptions.device,
485487
videoStreamOptions.ffmpegThreadCount);
486488

487489
auto& streamMetadata =
@@ -1222,11 +1224,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput(
12221224
formatContext_->streams[activeStreamIndex_]->time_base);
12231225
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12241226
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
1225-
} else if (!streamInfo.videoStreamOptions.device) {
1227+
} else if (!streamInfo.deviceInterface) {
12261228
convertAVFrameToFrameOutputOnCPU(
12271229
avFrame, frameOutput, preAllocatedOutputTensor);
1228-
} else if (streamInfo.videoStreamOptions.device) {
1229-
streamInfo.videoStreamOptions.device->convertAVFrameToFrameOutput(
1230+
} else if (streamInfo.deviceInterface) {
1231+
streamInfo.deviceInterface->convertAVFrameToFrameOutput(
12301232
streamInfo.videoStreamOptions,
12311233
avFrame,
12321234
frameOutput,
@@ -1569,10 +1571,8 @@ SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput(
15691571
videoStreamOptions, streamMetadata);
15701572
int height = frameDims.height;
15711573
int width = frameDims.width;
1572-
torch::Device device = (videoStreamOptions.device)
1573-
? videoStreamOptions.device->device()
1574-
: torch::kCPU;
1575-
data = allocateEmptyHWCTensor(height, width, device, numFrames);
1574+
data = allocateEmptyHWCTensor(
1575+
height, width, videoStreamOptions.device, numFrames);
15761576
}
15771577

15781578
torch::Tensor allocateEmptyHWCTensor(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class SingleStreamDecoder {
139139
std::optional<int> height;
140140
std::optional<ColorConversionLibrary> colorConversionLibrary;
141141
// By default we use CPU for decoding for both C++ and Python users.
142-
std::shared_ptr<DeviceInterface> device;
142+
torch::Device device = torch::kCPU;
143143
};
144144

145145
struct AudioStreamOptions {
@@ -358,6 +358,8 @@ class SingleStreamDecoder {
358358
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
359359
// be created before decoding a new frame.
360360
DecodedFrameContext prevFrameContext;
361+
362+
std::unique_ptr<DeviceInterface> deviceInterface;
361363
};
362364

363365
// --------------------------------------------------------------------------
@@ -460,7 +462,7 @@ class SingleStreamDecoder {
460462
void addStream(
461463
int streamIndex,
462464
AVMediaType mediaType,
463-
DeviceInterface* device = nullptr,
465+
const torch::Device& device = torch::kCPU,
464466
std::optional<int> ffmpegThreadCount = std::nullopt);
465467

466468
// Returns the "best" stream index for a given media type. The "best" is

src/torchcodec/_core/custom_ops.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ void _add_video_stream(
239239
}
240240
}
241241
if (device.has_value()) {
242-
videoStreamOptions.device =
243-
createDeviceInterface(std::string(device.value()));
242+
videoStreamOptions.device = createTorchDevice(std::string(device.value()));
244243
}
245244

246245
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

0 commit comments

Comments
 (0)