diff --git a/src/Solana.Unity.Dex/Orca/Address/AddressUtils.cs b/src/Solana.Unity.Dex/Orca/Address/AddressUtils.cs index 94194d91..d3f8fcde 100644 --- a/src/Solana.Unity.Dex/Orca/Address/AddressUtils.cs +++ b/src/Solana.Unity.Dex/Orca/Address/AddressUtils.cs @@ -20,28 +20,8 @@ public static class AddressUtils /// A Pda (program-derived address) off the curve, or null. public static Pda FindProgramAddress(IEnumerable seeds, PublicKey programId) { - byte nonce = 255; - - while (nonce != 0) - { - PublicKey address; - List seedsWithNonce = new List(seeds); - seedsWithNonce.Add(new byte[] { nonce }); - - //try to generate the address - bool created = PublicKey.TryCreateProgramAddress(new List(seedsWithNonce), programId, out address); - - //if succeeded, return - if (created) - { - return new Pda(address, nonce); - } - - //decrease the nonce and retry if failed - nonce--; - } - - return null; + return PublicKey.TryFindProgramAddress(seeds, programId, out PublicKey pubkey, out byte bump) + ? new Pda(pubkey, bump) : null; } /// diff --git a/src/Solana.Unity.Wallet/PublicKey.cs b/src/Solana.Unity.Wallet/PublicKey.cs index fa249849..92dab57a 100644 --- a/src/Solana.Unity.Wallet/PublicKey.cs +++ b/src/Solana.Unity.Wallet/PublicKey.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Security.Cryptography; using System.Text; +using System.Threading; namespace Solana.Unity.Wallet { @@ -256,7 +257,22 @@ public static bool IsValid(ReadOnlySpan key, bool validateCurve = false) /// /// The bytes of the `ProgramDerivedAddress` string. /// - private static readonly byte[] ProgramDerivedAddressBytes = Encoding.UTF8.GetBytes("ProgramDerivedAddress"); + private static readonly byte[] ProgramDerivedAddressBytes = "ProgramDerivedAddress"u8.ToArray(); + + /// + /// A thread-local buffer for all Pda Seeds taht we can reuse + /// + private static readonly ThreadLocal> PdaSeedsBuffer = new(() => []); + + /// + /// A thread local buffer for the bump array that we can reuse + /// + private static readonly ThreadLocal BumpArray = new(() => new byte[1]); + + /// + /// A thread local sha256 instance that we can reuse + /// + private static readonly ThreadLocal Sha256 = new(SHA256.Create); /// /// Derives a program address. @@ -264,36 +280,38 @@ public static bool IsValid(ReadOnlySpan key, bool validateCurve = false) /// The address seeds. /// The program Id. /// The derived public key, returned as inline out. - /// true if it could derive the program address for the given seeds, otherwise false.. + /// true if it could derive the program address for the given seeds, otherwise false. /// Throws exception when one of the seeds has an invalid length. public static bool TryCreateProgramAddress(ICollection seeds, PublicKey programId, out PublicKey publicKey) { - MemoryStream buffer = new(PublicKeyLength * seeds.Count + ProgramDerivedAddressBytes.Length + programId.KeyBytes.Length); - + SHA256 sha256 = Sha256.Value; + sha256.Initialize(); + foreach (byte[] seed in seeds) { if (seed.Length > PublicKeyLength) { throw new ArgumentException("max seed length exceeded", nameof(seeds)); } - buffer.Write(seed,0, seed.Length); + + sha256.TransformBlock(seed,0, seed.Length, null, 0); } - buffer.Write(programId.KeyBytes, 0, programId.KeyBytes.Length); - buffer.Write(ProgramDerivedAddressBytes, 0, ProgramDerivedAddressBytes.Length); - - SHA256 sha256 = SHA256.Create(); - byte[] hash = sha256.ComputeHash(new ReadOnlySpan(buffer.GetBuffer(), 0, (int)buffer.Length).ToArray()); - + sha256.TransformBlock(programId.KeyBytes,0, programId.KeyBytes.Length, null,0); + sha256.TransformBlock(ProgramDerivedAddressBytes, 0, ProgramDerivedAddressBytes.Length, null, 0); + sha256.TransformFinalBlock([], 0, 0); + + byte[] hash = sha256.Hash!; if (hash.IsOnCurve()) { publicKey = null; return false; } - publicKey = new(hash); + + publicKey = new PublicKey(hash); return true; } - + /// /// Attempts to find a program address for the passed seeds and program Id. /// @@ -304,28 +322,45 @@ public static bool TryCreateProgramAddress(ICollection seeds, PublicKey /// True whenever the address for a nonce was found, otherwise false. public static bool TryFindProgramAddress(IEnumerable seeds, PublicKey programId, out PublicKey address, out byte bump) { - byte seedBump = 255; - List buffer = seeds.ToList(); - var bumpArray = new byte[1]; - buffer.Add(bumpArray); - - while (seedBump != 0) + List pdaSeedsBuffer = PdaSeedsBuffer.Value; + pdaSeedsBuffer.Clear(); + pdaSeedsBuffer.AddRange(seeds); + + if (pdaSeedsBuffer.Any(seed => seed.Length > PublicKeyLength)) { - bumpArray[0] = seedBump; - bool success = TryCreateProgramAddress(buffer, programId, out PublicKey derivedAddress); + throw new ArgumentException("max seed length exceeded", nameof(seeds)); + } + + byte[] bumpArray = BumpArray.Value; + SHA256 sha256 = Sha256.Value; + + for (bump = 255; ; bump--) + { + sha256.Initialize(); - if (success) + foreach (byte[] seed in pdaSeedsBuffer) + { + sha256.TransformBlock(seed, 0, seed.Length, null, 0); + } + + bumpArray[0] = bump; + sha256.TransformBlock(bumpArray, 0, 1, null, 0); + sha256.TransformBlock(programId.KeyBytes, 0, programId.KeyBytes.Length, null, 0); + sha256.TransformBlock(ProgramDerivedAddressBytes, 0, ProgramDerivedAddressBytes.Length, null, 0); + sha256.TransformFinalBlock([], 0, 0); + + byte[] hash = sha256.Hash!; + if (!hash.IsOnCurve()) { - address = derivedAddress; - bump = seedBump; + address = new PublicKey(hash); return true; } - seedBump--; + if (bump == 0) + break; } - address = null; - bump = 0; + address = null!; return false; }