Merged (changes from 3 commits)
16 changes: 11 additions & 5 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -139,13 +139,19 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
     ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
     encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
 
-    if (parameters_->decoder_start_token_id < 0) {
-      ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
-                    "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+    if (!t5_encoder_subgraph_->HasLogitsOutput()) {
+      // New format requires start token id.
+      ORT_ENFORCE(parameters_->decoder_start_token_id >= 0);
     } else {
-      ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
-                    "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      if (parameters_->decoder_start_token_id < 0) {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+                      "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+      } else {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+                      "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      }
     }
+
   } else if (attribute_name == "decoder") {
     ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
                 "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
71 changes: 47 additions & 24 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -51,7 +51,13 @@ class BeamSearchT5 : public BeamSearchBase<T> {
       expand_buffer_int32_func_(expand_buffer_int32_func),
       expand_buffer_float_func_(expand_buffer_float_func),
       expand_buffer_float16_func_(expand_buffer_float16_func),
-      create_beam_scorer_func_(create_beam_scorer_func) {}
+      create_beam_scorer_func_(create_beam_scorer_func) {
+    // When the decoder uses encoder_hidden_states, make sure the encoder outputs it.
+    if (decoder_subgraph_.UseEncoderHiddenState()) {
+      ORT_ENFORCE(encoder_subgraph_.subgraph_output_names[1] == "encoder_hidden_states");
+    }
+    ORT_ENFORCE(encoder_subgraph_.num_layers == decoder_subgraph_.num_layers);
+  }
 
 #ifdef USE_CUDA
   Status InitializeCuda(
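The constructor now fails fast on encoder/decoder mismatches instead of failing later at run time. A standalone sketch of the same invariants; `CheckEncoderDecoderCompatibility` and its parameters are assumed names for illustration:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Sketch only: the encoder's second output must be encoder_hidden_states when the
// decoder consumes it, and both subgraphs must agree on the number of layers so
// that the per-layer past/present kv-cache tensors line up.
void CheckEncoderDecoderCompatibility(bool decoder_uses_encoder_hidden_state,
                                      const std::vector<std::string>& encoder_output_names,
                                      int encoder_num_layers,
                                      int decoder_num_layers) {
  if (decoder_uses_encoder_hidden_state) {
    assert(encoder_output_names.size() > 1);
    assert(encoder_output_names[1] == "encoder_hidden_states");
  }
  assert(encoder_num_layers == decoder_num_layers);
}
```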
@@ -160,7 +166,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                            this->create_encoder_inputs_func_,
                                            this->add_to_feeds_func_,
                                            buffer,
-                                           decoder_input_ids,
+                                           decoder_input_ids,  // The new format does not use decoder_input_ids in the encoder; it is still initialized here when decoder_start_token_id >= 0.
                                            this->ort_stream_));
 
 #ifdef DEBUG_NODE_INPUTS_OUTPUTS
@@ -233,35 +239,47 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

     std::vector<OrtValue> decoder_fetches;
 
-    if (current_length + 1 < parameters->max_length) {
+    // When the encoder outputs logits (old format), we need to get the next token from the logits.
+    if (current_length + 1 < parameters->max_length && encoder_subgraph_.HasLogitsOutput()) {
       ++iteration_counter;
-      ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
+      const OrtValue& logits = encoder_fetches[0];
+      ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
                                                   beam_next_tokens,
                                                   beam_state,
                                                   cpu_state,
                                                   iteration_counter));
       ++current_length;  // Increase sequence length after a new token is generated.
     }
 
-    ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(this->cpu_allocator_,
-                                                             ReinterpretAsSpan<const int32_t>(beam_next_tokens),
-                                                             this->implicit_inputs_,
-                                                             encoder_feeds,
-                                                             encoder_fetches,
-                                                             decoder_feeds,
-                                                             this->device_copy_int32_func_,
-                                                             this->expand_buffer_int32_func_,
-                                                             this->expand_buffer_float_func_,
-                                                             this->expand_buffer_float16_func_,
-                                                             parameters->num_beams,
-                                                             this->ort_stream_,
-                                                             decoder_subgraph_.UseSequenceAsInputIds(),
-                                                             current_length,
-                                                             cpu_state.sequences,
-                                                             parameters->max_length,
-                                                             decoder_subgraph_.has_decoder_masked_attention_,
-                                                             this->cuda_device_prop_ != nullptr));
+    if (current_length < parameters->max_length) {
+      // When there are no logits, copy the sequence (filled with start token IDs) to input_ids for the decoder.
+      bool copy_sequence_to_input_ids = decoder_subgraph_.UseSequenceAsInputIds() || !encoder_subgraph_.HasLogitsOutput();
+      if (copy_sequence_to_input_ids) {
+        ORT_ENFORCE(current_length == cpu_state.sequences.GetSequenceLength());
+      }
+
+      // Generate inputs for the next decoder subgraph call.
+      ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(
+          this->cpu_allocator_,
+          ReinterpretAsSpan<const int32_t>(beam_next_tokens),
+          this->implicit_inputs_,
+          encoder_feeds,
+          encoder_fetches,
+          decoder_feeds,
+          this->device_copy_int32_func_,
+          this->expand_buffer_int32_func_,
+          this->expand_buffer_float_func_,
+          this->expand_buffer_float16_func_,
+          parameters->num_beams,
+          this->ort_stream_,
+          copy_sequence_to_input_ids,
+          cpu_state.sequences,
+          parameters->max_length,
+          decoder_subgraph_.has_decoder_masked_attention_,
+          this->cuda_device_prop_ != nullptr));
+
     if (decoder_subgraph_.past_present_share_buffer_) {
+      // Configure buffer sharing of the past and present kv cache.
       decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
                               2 * static_cast<size_t>(decoder_subgraph_.num_layers));
       decoder_fetches.resize(decoder_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
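On the `past_present_share_buffer_` branch above: the fetches vector is reserved for the regular outputs plus two kv tensors per layer, and the slots before the first present output are pre-sized with empty `OrtValue`s. A sketch of the implied layout; `DecoderFetchLayout` and the `int` placeholders are illustrative assumptions, not ORT types:

```cpp
#include <vector>

// Sketch only: indices [0, first_present_index) hold the regular outputs (e.g. logits);
// after that comes one self-attention K/V pair per layer, which can alias the past
// buffer when past/present sharing is enabled.
struct DecoderFetchLayout {
  int first_present_index;  // GetFirstPresentOutputIndex() in the real code
  int num_layers;

  size_t TotalFetches() const {
    return static_cast<size_t>(first_present_index) + 2 * static_cast<size_t>(num_layers);
  }
};

std::vector<int> MakePlaceholderFetches(const DecoderFetchLayout& layout) {
  std::vector<int> fetches;                       // stand-in for std::vector<OrtValue>
  fetches.reserve(layout.TotalFetches());         // mirrors decoder_fetches.reserve(...)
  fetches.resize(layout.first_present_index, 0);  // mirrors decoder_fetches.resize(..., OrtValue())
  return fetches;
}
```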
@@ -299,14 +317,19 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

     while (current_length < parameters->max_length) {
       iteration_counter++;
+
 #ifdef DEBUG_GENERATION
-      auto cur_len = std::to_string(current_length);
-      dumper->Print("***CurrentLength", cur_len, true);
+      dumper->Print(::onnxruntime::MakeString("Iteration=", iteration_counter,
+                                              ", CurrentLength=", current_length,
+                                              ", num_layers=", decoder_subgraph_.num_layers,
+                                              ", decoder_feeds=", decoder_feeds.size(),
+                                              ", start_token_id=", parameters->decoder_start_token_id));
+
       for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) {
         dumper->Print("decoder_feeds", i, true);
         dumper->Print("", decoder_feeds[i]);
       }
 
       for (int i = 0; i < decoder_subgraph_.num_layers; i++) {
         int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i;
         int self_value_idx = self_key_idx + 1;
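Taken together, the changes in this file split the first decoder step by format. A compact standalone sketch of that control flow; `PrepareFirstDecoderStep` and its parameters are assumed names for illustration:

```cpp
#include <cstdio>

// Sketch only: with encoder logits (old format) the first token comes from
// GenerateNextToken; without them (new format) the sequences, pre-filled with
// decoder_start_token_id, are copied into the decoder's input_ids instead.
void PrepareFirstDecoderStep(bool encoder_has_logits,
                             bool decoder_uses_sequence_as_input_ids,
                             int& current_length,
                             int max_length) {
  if (current_length + 1 < max_length && encoder_has_logits) {
    // GenerateNextToken(encoder_logits, ...) would run here.
    ++current_length;  // one token was appended to every beam's sequence
  }
  const bool copy_sequence_to_input_ids =
      decoder_uses_sequence_as_input_ids || !encoder_has_logits;
  if (copy_sequence_to_input_ids) {
    std::printf("feed the whole sequence (length %d) as decoder input_ids\n", current_length);
  } else {
    std::printf("feed only the latest generated token as decoder input_ids\n");
  }
}
```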
18 changes: 8 additions & 10 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -36,12 +36,9 @@ Subgraph::Subgraph(
   auto& subgraph_inputs = subgraph.GetInputs();
   auto& subgraph_outputs = subgraph.GetOutputs();
 
-  // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
-  // outputs: logits, present_0, present_1, ...
   num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
   num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
 
-  // CheckSubgraph will verify inputs and outputs later.
   subgraph_input_names.reserve(num_subgraph_inputs);
   for (int i = 0; i < num_subgraph_inputs; ++i) {
     subgraph_input_names.push_back(subgraph_inputs[i]->Name());
@@ -68,10 +65,9 @@ Status Subgraph::Setup(const SessionState& session_state,
   InlinedVector<std::string_view> feed_names;
   feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
 
-  // Use the first output (logits) to find device location.
+  // Use the first output to find device location.
   const OrtDevice& default_location = utils::FindDeviceForValue(subgraph_session_state, subgraph_output_names[0]);
 
-  // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
   feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
 
   const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
@@ -174,13 +170,15 @@ Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shap
   }
 
   // Logits shape is like (batch_size, seq_len, vocabulary_size)
-  ORT_RETURN_IF(logits_shape->dim_size() != 3,
-                "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
+  if (logits_shape != nullptr) {
+    ORT_RETURN_IF(logits_shape->dim_size() != 3,
+                  "subgraph logits output is expected to have 3 dimensions, got ", logits_shape->dim_size());
 
-  ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
-                "subgraph past state dimension 2 shall have a positive value for vocabulary size");
+    ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
+                  "subgraph logits dimension 2 shall have a positive value for vocabulary size");
 
-  this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+    this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+  }
 
   return Status::OK();
 }
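Since the new encoder format has no logits output, `GetParameters` can now receive a null logits shape and skip the vocabulary-size deduction. A standalone sketch of the relaxed check, simplified to a plain vector (the real code inspects an ONNX `TensorShapeProto`):

```cpp
#include <stdexcept>
#include <vector>

// Sketch only: when a logits output exists it must be (batch_size, seq_len, vocab_size)
// with a static, positive vocab dimension; otherwise vocab_size is left unset (0 here).
int GetVocabSizeOrZero(const std::vector<long long>* logits_shape) {
  if (logits_shape == nullptr) {
    return 0;  // new format: the subgraph has no logits output
  }
  if (logits_shape->size() != 3) {
    throw std::runtime_error("subgraph logits output is expected to have 3 dimensions");
  }
  if ((*logits_shape)[2] <= 0) {
    throw std::runtime_error("dimension 2 shall have a positive value for vocabulary size");
  }
  return static_cast<int>((*logits_shape)[2]);
}
```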