 
 namespace gcpp {
 
+void KVCache::ZeroGriffinCache() {
+  if (conv1d_cache_size != 0) {
+    hwy::ZeroBytes(conv1d_cache.get(),
+                   conv1d_cache_size * sizeof(conv1d_cache[0]));
+  }
+  if (rglru_cache_size != 0) {
+    hwy::ZeroBytes(rglru_cache.get(),
+                   rglru_cache_size * sizeof(rglru_cache[0]));
+  }
+}
+
 // prefill_tbatch_size is the maximum number of tokens from one query to
 // prefill at a time.
 KVCache KVCache::Create(const ModelConfig& weights_config,
@@ -37,9 +48,9 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
     kv_cache.kv_cache =
         hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
   }
-  size_t num_griffin_layers = weights_config.NumLayersOfType(
-      LayerAttentionType::kGriffinRecurrentBlock);
 
+  const size_t num_griffin_layers = weights_config.NumLayersOfType(
+      LayerAttentionType::kGriffinRecurrentBlock);
   // TODO(patrickms): Add query batching support for Griffin.
   if (num_griffin_layers > 0) {
     size_t conv1d_width = 0;
@@ -49,20 +60,18 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
     const size_t conv1d_cache_size =
         num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
         weights_config.model_dim;
+    kv_cache.conv1d_cache_size = conv1d_cache_size;
     if (conv1d_cache_size != 0) {
       kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
-      hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
-                     conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0]));
     }
 
     const size_t rglru_cache_size =
         num_griffin_layers * weights_config.model_dim;
+    kv_cache.rglru_cache_size = rglru_cache_size;
     if (rglru_cache_size != 0) {
       kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
-      hwy::ZeroBytes(kv_cache.rglru_cache.get(),
-                     rglru_cache_size * sizeof(kv_cache.rglru_cache[0]));
     }
-  }  // kGriffinLayers
+  }  // num_griffin_layers
 
   return kv_cache;
 }
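
The commit records the conv1d and RG-LRU buffer sizes as KVCache members and moves the zeroing of the Griffin state out of Create() into the new ZeroGriffinCache(), so the recurrent state can be cleared without reallocating the buffers. A minimal usage sketch, not part of the commit, assuming Create() is a static factory whose second parameter is the prefill_tbatch_size mentioned in the comment and that weights_config is supplied by the caller:

// Hypothetical caller code; weights_config and prefill_tbatch_size are
// assumed to come from the surrounding application, not from this commit.
gcpp::KVCache cache =
    gcpp::KVCache::Create(weights_config, prefill_tbatch_size);

// ... prefill and decode one prompt ...

// Clear the Griffin conv1d and RG-LRU state before reusing the cache for an
// unrelated prompt; the aligned allocations themselves are kept.
cache.ZeroGriffinCache();

Separating the reset from allocation presumably lets a caller reuse one allocation across prompts and choose when (or whether) to pay for the zeroing.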