From ed93443302e745fdf62e605fa6ee25a0842e9474 Mon Sep 17 00:00:00 2001
From: fgh1999
Date: Wed, 21 Aug 2024 16:23:07 +0800
Subject: [PATCH] Fix length assertion in `Tensor::slice`; Fix the shape comment of `q` in `self_attention`.

---
 src/model.rs  | 2 +-
 src/tensor.rs | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index d59bf0e0..3d674ab3 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -144,7 +144,7 @@ impl Llama {
 fn self_attention(
     hidden_states: &mut Tensor, // (seq, n_kv_h * n_groups * dqkv)
     att_scores: &mut Tensor, // (n_kv_h, n_groups, seq, total_seq)
-    q: &Tensor, // (seq, n_kv_h * n_groups * dqkv)
+    q: &Tensor, // (seq, n_kv_h * n_groups, dqkv)
     k: &Tensor, // (total_seq, n_kv_h * dqkv)
     v: &Tensor, // (total_seq, n_kv_h * dqkv)
     n_kv_h: usize,
diff --git a/src/tensor.rs b/src/tensor.rs
index b56d2dd9..9864f746 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -53,7 +53,7 @@ impl Tensor {
 
     pub fn slice(&self, start: usize, shape: &Vec<usize>) -> Self {
         let new_length: usize = shape.iter().product();
-        assert!(self.offset + start + new_length <= self.length);
+        assert!(new_length <= self.length && start <= self.length - new_length);
         Tensor {
             data: self.data.clone(),
             shape: shape.clone(),
@@ -61,8 +61,6 @@ impl Tensor {
             length: new_length,
         }
     }
-
-
 }
 
 // Some helper functions for testing and debugging