diff --git a/gemma/activations.h b/gemma/activations.h
index 20d938c2..05c9ee0a 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -65,6 +65,9 @@ struct AttentionActivations {
                   ? batch_size * layer_config.heads * 3
                   : batch_size * layer_config.heads,
               allocator)),
+        vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
+        vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
+        vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
         pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                    config.model_dim, allocator)),
         att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@@ -96,6 +99,7 @@ struct AttentionActivations {
     q.AllocateAndAttachRowPtrs(row_ptrs);
     q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
+    vit_C.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }
 
@@ -104,6 +108,10 @@ struct AttentionActivations {
     q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
 
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
+
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
@@ -116,6 +124,10 @@ struct AttentionActivations {
   MatStorageT q_bf;
   MatStorageT q_T;  // Transposed to maximize attention speed.
 
+  MatStorageT vit_Q;
+  MatStorageT vit_K;
+  MatStorageT vit_C;
+
   MatStorageT pre_att_rms_out;
   MatStorageT att;      // attention vector
   MatStorageT att_out;  // attention output
@@ -141,6 +153,9 @@ struct AttentionActivationsPtrs {
     q = activations.q;
     q_bf = activations.q_bf;
     q_T = activations.q_T;
+    vit_Q = activations.vit_Q;
+    vit_K = activations.vit_K;
+    vit_C = activations.vit_C;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
@@ -153,6 +168,11 @@ struct AttentionActivationsPtrs {
     q.OverrideRows(batch_size);
     q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
+
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
+
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
@@ -168,6 +188,11 @@ struct AttentionActivationsPtrs {
   MatPtrT q;
   MatPtrT q_bf;
   MatPtrT q_T;
+
+  MatPtrT vit_Q;
+  MatPtrT vit_K;
+  MatPtrT vit_C;
+
   MatPtrT pre_att_rms_out;
   MatPtrT att;
   MatPtrT att_out;
diff --git a/gemma/run.cc b/gemma/run.cc
index 7e2059fd..25be2c6a 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -130,7 +130,10 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
                                 float) {
     std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector{token}, &token_text));
+    if (!gemma.Tokenizer().Decode(std::vector{token}, &token_text)) {
+      if (token == -2) return true;  // Gemma 3 ViT?
+      HWY_WARN("Failed to decode token %d.", token);
+    }
     HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
 
diff --git a/gemma/vit.cc b/gemma/vit.cc
index b00efda5..7c3e2416 100644
--- a/gemma/vit.cc
+++ b/gemma/vit.cc
@@ -78,13 +78,9 @@ class VitAttention {
     const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
 
-    // Shift Q, K, VT to MatStorageT.
-    MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim),
-                  env_.ctx.allocator, MatPadding::kPacked);
-    MatStorageT K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
-                  MatPadding::kPacked);
-    MatStorageT C("C2", Extents2D(num_tokens_, seq_len),
-                  env_.ctx.allocator, MatPadding::kPacked);
+    MatPtrT& Q = activations_.attention.vit_Q;
+    MatPtrT& K = activations_.attention.vit_K;
+    MatPtrT& C = activations_.attention.vit_C;
 
     // Initialize att_out to zero prior to head loop.
     ZeroInit(activations_.attention.att_out);
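
Note on the vit.cc hunk above: instead of constructing fresh MatStorageT buffers on every DotSoftmax call, the Q/K/C scratch now lives in AttentionActivations and is only re-viewed per batch via OverrideRows. The following self-contained C++ sketch illustrates that allocate-once / override-rows pattern in isolation; ScratchMat and VitScratch are illustrative stand-ins for this note only, not gemma.cpp types.

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Pre-allocated scratch matrix whose logical row count can shrink per call.
class ScratchMat {
 public:
  ScratchMat(std::string name, size_t max_rows, size_t cols)
      : name_(std::move(name)),
        max_rows_(max_rows),
        rows_(max_rows),
        cols_(cols),
        data_(max_rows * cols) {}  // allocate once, at maximum capacity

  // Shrink the logical view without reallocating (analogous to OverrideRows).
  void OverrideRows(size_t rows) { rows_ = rows < max_rows_ ? rows : max_rows_; }

  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }
  float* Row(size_t r) { return data_.data() + r * cols_; }

 private:
  std::string name_;
  size_t max_rows_;
  size_t rows_;
  size_t cols_;
  std::vector<float> data_;
};

// Groups the three ViT scratch buffers the way the diff adds them to
// AttentionActivations: allocated once at construction, reused per batch.
struct VitScratch {
  VitScratch(size_t batch_size, size_t seq_len, size_t qkv_dim)
      : vit_Q("Q2", batch_size, qkv_dim),
        vit_K("K2", seq_len, qkv_dim),
        vit_C("C2", batch_size, seq_len) {}

  // Per-batch adjustment; vit_K keeps seq_len rows, as noted in the diff.
  void SetBatchSize(size_t batch_size) {
    vit_Q.OverrideRows(batch_size);
    vit_C.OverrideRows(batch_size);
  }

  ScratchMat vit_Q;  // one query row per token in the batch
  ScratchMat vit_K;  // keys for the full image token sequence
  ScratchMat vit_C;  // attention scores: batch rows x seq_len columns
};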