Skip to content

Commit 3dd6110

Browse files
committed
Support override-tensors parameter
ggml-org/llama.cpp#11397
1 parent 2a56c95 commit 3dd6110

10 files changed

+260
-6
lines changed

LLama.Web/Common/ModelOptions.cs

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ public class ModelOptions
2626
/// <inheritdoc />
2727
public GPUSplitMode? SplitMode { get; set; }
2828

29+
/// <inheritdoc />
30+
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();
31+
2932
/// <inheritdoc />
3033
public int GpuLayerCount { get; set; } = 20;
3134

LLama/Abstractions/IModelParams.cs

+6
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ public interface IModelParams
3838
/// </summary>
3939
GPUSplitMode? SplitMode { get; }
4040

41+
/// <summary>
42+
/// Buffer type overrides for specific tensor patterns, allowing you to specify hardware devices to use for individual tensors or sets of tensors.
43+
/// Equivalent to --override-tensor or -ot on the llama.cpp command line or tensor_buft_overrides internally.
44+
/// </summary>
45+
List<TensorBufferOverride> TensorBufferOverrides { get; }
46+
4147
/// <summary>
4248
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
4349
/// </summary>
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System;
2+
3+
namespace LLama.Abstractions
4+
{
5+
/// <summary>
6+
/// Represents a mapping between a tensor name pattern and a specific buffer type
7+
/// </summary>
8+
public class TensorBufferOverride
9+
{
10+
/// <summary>
11+
/// Pattern to match tensor names. This is a regular expression. You can check the tensor names via the model.Metadata.
12+
/// </summary>
13+
public string Pattern { get; set; }
14+
15+
/// <summary>
16+
/// Buffer type to use for matching tensors. Examples: CPU, GPU0, GPU1
17+
/// </summary>
18+
public string BufferType { get; set; }
19+
20+
/// <summary>
21+
/// Creates a new tensor buffer override
22+
/// </summary>
23+
/// <param name="pattern">Pattern to match tensor names</param>
24+
/// <param name="bufferType">Buffer type to use for matching tensors</param>
25+
public TensorBufferOverride(string pattern, string bufferType)
26+
{
27+
if (string.IsNullOrEmpty(pattern))
28+
throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern));
29+
if (string.IsNullOrEmpty(bufferType))
30+
throw new ArgumentException("Buffer type cannot be null or empty", nameof(bufferType));
31+
32+
Pattern = pattern;
33+
BufferType = bufferType;
34+
}
35+
}
36+
}

LLama/Common/ModelParams.cs

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public record ModelParams
2121
/// <inheritdoc />
2222
public GPUSplitMode? SplitMode { get; set; }
2323

24+
/// <inheritdoc />
25+
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();
26+
2427
/// <inheritdoc />
2528
public int GpuLayerCount { get; set; } = 20;
2629

LLama/Extensions/IModelParamsExtensions.cs

+15
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ namespace LLama.Extensions;
1111
/// </summary>
1212
public static class IModelParamsExtensions
1313
{
14+
private static LLamaTensorBufferOverrideHelper bufferOverrideHelper = new();
15+
1416
/// <summary>
1517
/// Convert the given `IModelParams` into a `LLamaModelParams`
1618
/// </summary>
@@ -45,6 +47,19 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
4547
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
4648
}
4749

50+
// Add tensor buffer overrides, if any
51+
if (@params.TensorBufferOverrides.Count > 0)
52+
{
53+
disposer.Add(bufferOverrideHelper);
54+
55+
foreach (var tensorOverride in @params.TensorBufferOverrides)
56+
{
57+
bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
58+
}
59+
60+
bufferOverrideHelper.ApplyToModelParams(ref result);
61+
}
62+
4863
if (@params.MetadataOverrides.Count == 0)
4964
{
5065
unsafe

LLama/Native/LLamaModelParams.cs

+7-6
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ public unsafe struct LLamaModelParams
1212
/// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
1313
/// todo: add support for llama_model_params.devices
1414
/// </summary>
15-
private IntPtr devices;
15+
private IntPtr devices;
1616

17-
// NULL-terminated list of buffer types to use for tensors that match a pattern
18-
// actual type: llama_model_tensor_buft_override*
19-
// todo: add support for tensor_buft_overrides
20-
private IntPtr tensor_buft_overrides;
17+
/// <summary>
18+
/// NULL-terminated list of buffer types to use for tensors that match a pattern
19+
/// actual type: llama_model_tensor_buft_override*
20+
/// </summary>
21+
public IntPtr tensor_buft_overrides;
2122

2223
/// <summary>
2324
/// // number of layers to store in VRAM
@@ -111,6 +112,6 @@ public static LLamaModelParams Default()
111112

112113
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
113114
static extern LLamaModelParams llama_model_default_params();
114-
}
115+
}
115116
}
116117
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
3+
namespace LLama.Native
4+
{
5+
/// <summary>
6+
/// Represents a mapping between a tensor name pattern and a backend buffer type<br/>
7+
/// Original type: llama_model_tensor_buft_override
8+
/// </summary>
9+
[StructLayout(LayoutKind.Sequential)]
10+
public struct LLamaModelTensorBufferOverride
11+
{
12+
/// <summary>
13+
/// Tensor name pattern to match
14+
/// </summary>
15+
public IntPtr Pattern;
16+
17+
/// <summary>
18+
/// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
19+
/// </summary>
20+
public IntPtr BufferType;
21+
}
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace LLama.Native
6+
{
7+
/// <summary>
8+
/// Helper for creating and managing tensor buffer overrides
9+
/// </summary>
10+
public class LLamaTensorBufferOverrideHelper : IDisposable
11+
{
12+
private readonly List<IntPtr> _allocatedMemory = new();
13+
private readonly List<LLamaModelTensorBufferOverride> _overrides = new();
14+
private IntPtr _overrideArray = IntPtr.Zero;
15+
private readonly Dictionary<string, IntPtr> _bufferTypeCache = new();
16+
17+
/// <summary>
18+
/// Get all available buffer types
19+
/// </summary>
20+
/// <returns>Dictionary mapping buffer type names to their handles</returns>
21+
public Dictionary<string, IntPtr> GetAvailableBufferTypes()
22+
{
23+
var result = new Dictionary<string, IntPtr>();
24+
25+
nuint count = NativeApi.ggml_backend_dev_count();
26+
for (nuint i = 0; i < count; i++)
27+
{
28+
IntPtr dev = NativeApi.ggml_backend_dev_get(i);
29+
IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);
30+
31+
if (buft != IntPtr.Zero)
32+
{
33+
IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
34+
string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;
35+
36+
if (!string.IsNullOrEmpty(name))
37+
{
38+
result[name] = buft;
39+
_bufferTypeCache[name] = buft;
40+
}
41+
}
42+
}
43+
44+
return result;
45+
}
46+
47+
/// <summary>
48+
/// Add a tensor buffer override
49+
/// </summary>
50+
/// <param name="pattern">Tensor name pattern to match</param>
51+
/// <param name="bufferTypeName">Name of the buffer type to use</param>
52+
/// <returns>True if the override was added successfully</returns>
53+
public bool AddOverride(string pattern, string bufferTypeName)
54+
{
55+
if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
56+
return false;
57+
58+
// Get all buffer types if cache is empty
59+
if (_bufferTypeCache.Count == 0)
60+
{
61+
GetAvailableBufferTypes();
62+
}
63+
64+
// Check if we have this buffer type
65+
if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
66+
return false;
67+
68+
// Allocate memory for the pattern string and keep track of it
69+
byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
70+
IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
71+
Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
72+
_allocatedMemory.Add(patternPtr);
73+
74+
// Create the override
75+
var @override = new LLamaModelTensorBufferOverride
76+
{
77+
Pattern = patternPtr,
78+
BufferType = bufferType
79+
};
80+
81+
_overrides.Add(@override);
82+
return true;
83+
}
84+
85+
/// <summary>
86+
/// Apply the overrides to model parameters
87+
/// </summary>
88+
/// <param name="modelParams">Model parameters to update</param>
89+
public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
90+
{
91+
if (_overrides.Count == 0)
92+
{
93+
modelParams.tensor_buft_overrides = IntPtr.Zero;
94+
return;
95+
}
96+
97+
// Free previous array if it exists
98+
if (_overrideArray != IntPtr.Zero)
99+
{
100+
Marshal.FreeHGlobal(_overrideArray);
101+
}
102+
103+
// Allocate memory for the array + null terminator
104+
int size = Marshal.SizeOf<LLamaModelTensorBufferOverride>() * (_overrides.Count + 1);
105+
_overrideArray = Marshal.AllocHGlobal(size);
106+
_allocatedMemory.Add(_overrideArray);
107+
108+
// Copy overrides to array
109+
for (int i = 0; i < _overrides.Count; i++)
110+
{
111+
IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
112+
Marshal.StructureToPtr(_overrides[i], elemPtr, false);
113+
}
114+
115+
// Add null terminator
116+
IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
117+
Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);
118+
119+
// Update model params
120+
modelParams.tensor_buft_overrides = _overrideArray;
121+
}
122+
123+
/// <inheritdoc />
124+
public void Dispose()
125+
{
126+
foreach (IntPtr ptr in _allocatedMemory)
127+
{
128+
Marshal.FreeHGlobal(ptr);
129+
}
130+
_allocatedMemory.Clear();
131+
_overrides.Clear();
132+
_overrideArray = IntPtr.Zero;
133+
}
134+
}
135+
}

LLama/Native/NativeApi.Load.cs

+2
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ private static void SetDllImportResolver()
107107

108108
internal const string libraryName = "llama";
109109
internal const string llavaLibraryName = "llava_shared";
110+
internal const string ggmlLibraryName = "ggml";
111+
internal const string ggmlBaseLibraryName = "ggml-base";
110112

111113
private static INativeLibrary? _loadedLLamaLibrary = null;
112114
private static INativeLibrary? _loadedLLavaLibrary = null;

LLama/Native/NativeApi.cs

+31
Original file line numberDiff line numberDiff line change
@@ -439,5 +439,36 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
439439
// it would expose the raw pointer to the model, without properly wrapping it in a SafeLLamaModelHandle.
440440
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
441441
//public static void llama_model* llama_get_model(SafeLLamaContextHandle ctx);
442+
443+
/// <summary>
444+
/// Get the number of available backend devices
445+
/// </summary>
446+
/// <returns>Count of available backend devices</returns>
447+
[DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
448+
public static extern nuint ggml_backend_dev_count();
449+
450+
/// <summary>
451+
/// Get a backend device by index
452+
/// </summary>
453+
/// <param name="i">Device index</param>
454+
/// <returns>Pointer to the backend device</returns>
455+
[DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
456+
public static extern IntPtr ggml_backend_dev_get(nuint i);
457+
458+
/// <summary>
459+
/// Get the buffer type for a backend device
460+
/// </summary>
461+
/// <param name="dev">Backend device pointer</param>
462+
/// <returns>Pointer to the buffer type</returns>
463+
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
464+
public static extern IntPtr ggml_backend_dev_buffer_type(IntPtr dev);
465+
466+
/// <summary>
467+
/// Get the name of a buffer type
468+
/// </summary>
469+
/// <param name="buft">Buffer type pointer</param>
470+
/// <returns>Name of the buffer type</returns>
471+
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
472+
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
442473
}
443474
}

0 commit comments

Comments
 (0)