Skip to content

Feat/tensor override #1180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/_typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ extend-exclude = [

[default.extend-words]
# Used in a comment in SafeLLamaSamplerHandle.cs, as a prefix of "hello"
teh = "hel"
teh = "hel"
# ot is the shorthand version of llama.cpp's override-tensor parameter
ot = "ot"
3 changes: 3 additions & 0 deletions LLama.Web/Common/ModelOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public class ModelOptions
/// <inheritdoc />
public GPUSplitMode? SplitMode { get; set; }

/// <inheritdoc />
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();

/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

Expand Down
6 changes: 6 additions & 0 deletions LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ public interface IModelParams
/// </summary>
GPUSplitMode? SplitMode { get; }

/// <summary>
/// Buffer type overrides for specific tensor patterns, allowing you to specify hardware devices to use for individual tensors or sets of tensors.
/// Equivalent to --override-tensor or -ot on the llama.cpp command line or tensor_buft_overrides internally.
/// </summary>
List<TensorBufferOverride> TensorBufferOverrides { get; }

/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
Expand Down
36 changes: 36 additions & 0 deletions LLama/Abstractions/TensorBufferOverride.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System;

namespace LLama.Abstractions
{
    /// <summary>
    /// Maps a tensor name pattern to the buffer type that matching tensors should be placed in
    /// </summary>
    public class TensorBufferOverride
    {
        /// <summary>
        /// Regular expression matched against tensor names. Tensor names can be inspected via model.Metadata.
        /// </summary>
        public string Pattern { get; set; }

        /// <summary>
        /// Name of the buffer type used for tensors whose name matches <see cref="Pattern"/>. Examples: CPU, GPU0, GPU1
        /// </summary>
        public string BufferType { get; set; }

        /// <summary>
        /// Creates a new tensor buffer override
        /// </summary>
        /// <param name="pattern">Pattern to match tensor names</param>
        /// <param name="bufferType">Buffer type to use for matching tensors</param>
        /// <exception cref="ArgumentException">Thrown when <paramref name="pattern"/> or <paramref name="bufferType"/> is null or empty</exception>
        public TensorBufferOverride(string pattern, string bufferType)
        {
            Pattern = !string.IsNullOrEmpty(pattern)
                ? pattern
                : throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern));
            BufferType = !string.IsNullOrEmpty(bufferType)
                ? bufferType
                : throw new ArgumentException("Buffer type cannot be null or empty", nameof(bufferType));
        }
    }
}
3 changes: 3 additions & 0 deletions LLama/Common/ModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ public record ModelParams
/// <inheritdoc />
public GPUSplitMode? SplitMode { get; set; }

/// <inheritdoc />
public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();

/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

Expand Down
15 changes: 15 additions & 0 deletions LLama/Extensions/IModelParamsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace LLama.Extensions;
/// </summary>
public static class IModelParamsExtensions
{
// NOTE(review): a single static helper shared by every ToLlamaModelParams call looks hazardous:
// it is registered with the per-call disposer (so the first caller's dispose frees its native
// memory while later calls reuse the same instance), and repeated calls keep accumulating
// overrides. Consider creating one helper per call instead — TODO confirm intended lifetime.
private static LLamaTensorBufferOverrideHelper bufferOverrideHelper = new();

/// <summary>
/// Convert the given `IModelParams` into a `LLamaModelParams`
/// </summary>
Expand Down Expand Up @@ -45,6 +47,19 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}

// Add tensor buffer overrides, if any
if (@params.TensorBufferOverrides.Count > 0)
{
disposer.Add(bufferOverrideHelper);

foreach (var tensorOverride in @params.TensorBufferOverrides)
{
bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
}

bufferOverrideHelper.ApplyToModelParams(ref result);
}

if (@params.MetadataOverrides.Count == 0)
{
unsafe
Expand Down
6 changes: 6 additions & 0 deletions LLama/Native/LLamaModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ public unsafe struct LLamaModelParams
/// todo: add support for llama_model_params.devices
/// </summary>
private IntPtr devices;

/// <summary>
/// NULL-terminated list of buffer types to use for tensors that match a pattern
/// actual type: llama_model_tensor_buft_override*
/// </summary>
public IntPtr tensor_buft_overrides;

/// <summary>
/// // number of layers to store in VRAM
Expand Down
22 changes: 22 additions & 0 deletions LLama/Native/LLamaModelTensorBufferOverride.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Runtime.InteropServices;

namespace LLama.Native
{
    /// <summary>
    /// Represents a mapping between a tensor name pattern and a backend buffer type<br/>
    /// Original type: llama_model_tensor_buft_override
    /// </summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct LLamaModelTensorBufferOverride
    {
        /// <summary>
        /// Tensor name pattern to match (const char* on the native side)
        /// </summary>
        public IntPtr Pattern;

        /// <summary>
        /// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
        /// </summary>
        public IntPtr BufferType;
    }
}
135 changes: 135 additions & 0 deletions LLama/Native/LLamaTensorBufferOverrideHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Native
{
/// <summary>
/// Helper for creating and managing tensor buffer overrides
/// </summary>
public class LLamaTensorBufferOverrideHelper : IDisposable
{
private readonly List<IntPtr> _allocatedMemory = new();
private readonly List<LLamaModelTensorBufferOverride> _overrides = new();
private IntPtr _overrideArray = IntPtr.Zero;
private readonly Dictionary<string, IntPtr> _bufferTypeCache = new();

/// <summary>
/// Get all available buffer types
/// </summary>
/// <returns>Dictionary mapping buffer type names to their handles</returns>
public Dictionary<string, IntPtr> GetAvailableBufferTypes()
{
var result = new Dictionary<string, IntPtr>();

nuint count = NativeApi.ggml_backend_dev_count();
for (nuint i = 0; i < count; i++)
{
IntPtr dev = NativeApi.ggml_backend_dev_get(i);
IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);

if (buft != IntPtr.Zero)
{
IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;

if (!string.IsNullOrEmpty(name))
{
result[name] = buft;
_bufferTypeCache[name] = buft;
}
}
}

return result;
}

/// <summary>
/// Add a tensor buffer override
/// </summary>
/// <param name="pattern">Tensor name pattern to match</param>
/// <param name="bufferTypeName">Name of the buffer type to use</param>
/// <returns>True if the override was added successfully</returns>
public bool AddOverride(string pattern, string bufferTypeName)
{
if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
return false;

// Get all buffer types if cache is empty
if (_bufferTypeCache.Count == 0)
{
GetAvailableBufferTypes();
}

// Check if we have this buffer type
if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
return false;

// Allocate memory for the pattern string and keep track of it
byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
_allocatedMemory.Add(patternPtr);

// Create the override
var @override = new LLamaModelTensorBufferOverride
{
Pattern = patternPtr,
BufferType = bufferType
};

_overrides.Add(@override);
return true;
}

/// <summary>
/// Apply the overrides to model parameters
/// </summary>
/// <param name="modelParams">Model parameters to update</param>
public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
{
if (_overrides.Count == 0)
{
modelParams.tensor_buft_overrides = IntPtr.Zero;
return;
}

// Free previous array if it exists
if (_overrideArray != IntPtr.Zero)
{
Marshal.FreeHGlobal(_overrideArray);
}

// Allocate memory for the array + null terminator
int size = Marshal.SizeOf<LLamaModelTensorBufferOverride>() * (_overrides.Count + 1);
_overrideArray = Marshal.AllocHGlobal(size);
_allocatedMemory.Add(_overrideArray);

// Copy overrides to array
for (int i = 0; i < _overrides.Count; i++)
{
IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
Marshal.StructureToPtr(_overrides[i], elemPtr, false);
}

// Add null terminator
IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);

// Update model params
modelParams.tensor_buft_overrides = _overrideArray;
}

/// <inheritdoc />
public void Dispose()
{
foreach (IntPtr ptr in _allocatedMemory)
{
Marshal.FreeHGlobal(ptr);
}
_allocatedMemory.Clear();
_overrides.Clear();
_overrideArray = IntPtr.Zero;
}
}
}
2 changes: 2 additions & 0 deletions LLama/Native/NativeApi.Load.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ private static void SetDllImportResolver()

internal const string libraryName = "llama";
internal const string llavaLibraryName = "llava_shared";
internal const string ggmlLibraryName = "ggml";
internal const string ggmlBaseLibraryName = "ggml-base";

private static INativeLibrary? _loadedLLamaLibrary = null;
private static INativeLibrary? _loadedLLavaLibrary = null;
Expand Down
31 changes: 31 additions & 0 deletions LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -439,5 +439,36 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
// it would expose the raw pointer to the model, without properly wrapping it in a SafeLLamaModelHandle.
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//public static void llama_model* llama_get_model(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the number of available backend devices
/// </summary>
/// <returns>Count of available backend devices</returns>
/// <remarks>
/// NOTE(review): this and ggml_backend_dev_get are imported from <c>ggml</c> while the two
/// functions below are imported from <c>ggml-base</c> — confirm this split matches where the
/// shipped native binaries actually export each symbol.
/// </remarks>
[DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern nuint ggml_backend_dev_count();

/// <summary>
/// Get a backend device by index
/// </summary>
/// <param name="i">Device index (valid range: 0 to ggml_backend_dev_count() - 1)</param>
/// <returns>Pointer to the backend device (opaque ggml_backend_dev_t handle)</returns>
[DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_dev_get(nuint i);

/// <summary>
/// Get the buffer type for a backend device
/// </summary>
/// <param name="dev">Backend device pointer, as obtained via ggml_backend_dev_get</param>
/// <returns>Pointer to the buffer type (opaque ggml_backend_buffer_type_t handle)</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_dev_buffer_type(IntPtr dev);

/// <summary>
/// Get the name of a buffer type
/// </summary>
/// <param name="buft">Buffer type pointer, as obtained via ggml_backend_dev_buffer_type</param>
/// <returns>Name of the buffer type (const char* on the native side; read with Marshal.PtrToStringAnsi)</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
}
}
Loading