@@ -605,6 +605,22 @@ Constants
605
605
606
606
.. attribute :: LAZY_ENABLE_PEER_ACCESS
607
607
608
+ .. class :: capture_mode
609
+
610
+ CUDA 10 and newer.
611
+
612
+ .. attribute :: GLOBAL
613
+ .. attribute :: THREAD_LOCAL
614
+ .. attribute :: RELAXED
615
+
616
+ .. class :: capture_status
617
+
618
+ CUDA 10 and newer.
619
+
620
+ .. attribute :: NONE
621
+ .. attribute :: ACTIVE
622
+ .. attribute :: INVALIDATED
623
+
608
624
609
625
Graphics-related constants
610
626
^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -845,6 +861,43 @@ Concurrency and Streams
845
861
846
862
.. versionadded :: 2011.1
847
863
864
+ .. method :: begin_capture(capture_mode=capture_mode.GLOBAL)
865
+
866
+ Begins graph stream capture on a stream.
867
+
868
+ When a stream is in capture mode, all operations pushed into the stream
869
+ will not be executed, but will instead be captured into a graph.
870
+
871
+ :arg capture_mode: A :class: `capture_mode ` specifying mode for capturing graph.
872
+
873
+ CUDA 10 and above.
874
+
875
+ .. method :: end_capture()
876
+
877
+ Ends stream capture and returns a :class: `Graph ` object.
878
+
879
+ CUDA 10 and above.
880
+
881
+ .. method :: get_capture_info_v2()
882
+
883
+ Query a stream's capture state.
884
+
885
+ Return a :class: `tuple ` of (:class: `capture_status ` capture status, :class: `int ` id for the capture sequence,
886
+ :class: `Graph ` the graph being captured into, a :class: `list ` of :class: `GraphNode ` specifying set of nodes the
887
+ next node to be captured in the stream will depend on)
888
+
889
+ CUDA 10 and above.
890
+
891
+ .. method :: update_capture_dependencies(dependencies, flags)
892
+
893
+ Modifies the dependency set of a capturing stream.
894
+ The dependency set is the set of nodes that the next captured node in the stream will depend on.
895
+
896
+ :arg dependencies: A :class: `list ` of :class: `GraphNode ` specifying the new list of dependencies.
897
+ :arg flags: A :class: `int ` controlling whether the set passed to the API is added to the existing set or replaces it.
898
+
899
+ CUDA 11.3 and above.
900
+
848
901
.. class :: Event(flags=0)
849
902
850
903
An event is a temporal 'marker' in a :class: `Stream ` that allows taking the time
@@ -895,6 +948,78 @@ Concurrency and Streams
895
948
896
949
.. versionadded: 2011.2
897
950
951
+ CUDAGraphs
952
+ ----------
953
+
954
+ CUDA 10.0 and above
955
+
956
+ Launching a simple kernel using CUDAGraphs
957
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
958
+
959
+ .. literalinclude :: ../examples/cudagraph_kernel.py
960
+
961
+ .. class :: GraphNode
962
+
963
+ An object representing a node on :class: `Graph `.
964
+
965
+ Wraps `cuGraphNode <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc72514a94dacc85ed0617f979211079c> `
966
+
967
+ .. class :: GraphExec
968
+
969
+ An executable graph to be launched on a stream.
970
+
971
+ Wraps `cuGraphExec <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gf0abeceeaa9f0a39592fe36a538ea1f0 >`_
972
+
973
+ .. method :: launch(stream_py=None)
974
+
975
+ Launches an executable graph in a stream.
976
+
977
+ :arg stream_py: :class: `Stream ` object specifying device stream.
978
+ Will use default stream if *stream_py * is None.
979
+
980
+ .. method :: kernel_node_set_params(*args, kernel_node, func=None, block=(), grid=(), shared_mem_bytes=0)
981
+
982
+ Sets a kernel node's parameters. Refer to :meth: `add_kernel_node ` for argument specifications.
983
+
984
+ .. class :: Graph()
985
+
986
+ A cudagraph is a data dependency graph meant to
987
+ serve as an alternative to :class: `Stream `.
988
+
989
+ Wraps `cuGraph <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g69f555c38df5b3fa1ed25efef794739a> `
990
+
991
+ .. method :: add_kernel_node(*args, func, block, grid=(1, ), dependencies=[], shared_mem_bytes=0)
992
+
993
+ Returns and adds a :class: `GraphNode ` object specifying
994
+ kernel node to the graph.
995
+
996
+ Will be placed at the root of the graph if dependencies
997
+ are not specified.
998
+
999
+ :arg args: *arg1 * through *argn * are the positional C arguments to the kernel.
1000
+ See :meth: `Function.__call__ ` for more argument details.
1001
+
1002
+ :arg func: a :class: `Function`object specifying kernel function.
1003
+
1004
+ :arg block: a :class:`tuple ` of up to three integer entries specifying the number
1005
+ of thread blocks to launch, as a multi-dimensional grid.
1006
+
1007
+ :arg grid: a :class: `tuple ` of up to three integer entries specifying the grid configuration.
1008
+
1009
+ :arg dependencies: A :class: `list ` of :class: `GraphNode ` objects specifying dependency nodes.
1010
+
1011
+ :arg shared_mem_bytes: A :class: `int ` specifying size of shared memory.
1012
+
1013
+ .. method :: instantiate()
1014
+
1015
+ Returns and instantiates a :class: `GraphExec ` object.
1016
+
1017
+ .. method :: debug_dot_print(path)
1018
+
1019
+ Returns a DOT file describing graph structure at specifed path.
1020
+
1021
+ :arg path: String specifying path for saving DOT file.
1022
+
898
1023
Memory
899
1024
------
900
1025
0 commit comments