Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 113 additions & 97 deletions ggml/src/ggml-openvino/ggml-decoder.cpp

Large diffs are not rendered by default.

140 changes: 92 additions & 48 deletions ggml/src/ggml-openvino/ggml-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,63 @@
#include <optional>
#include <vector>

// Parameters that are fixed for a loaded model (contrast with ComputeParams,
// which vary per inference). All integral fields default to -1, meaning
// "unset". Used to decide whether an already-converted OpenVINO model can be
// reused for a new ggml cgraph.
struct ModelParams {
    int ctx = -1;              // total context size
    int ctx_swa = -1;          // context size for sliding-window-attention (SWA) layers
    int ctx_per_seq = -1;      // context size per sequence
    int ctx_per_seq_swa = -1;  // SWA context size per sequence
    int n_seq = -1;            // number of sequences
    int n_heads = -1;          // number of attention heads
    int n_heads_kv = -1;       // number of key/value heads
    int head_size = -1;        // size of each attention head
    int32_t * rope_params = nullptr;  // RoPE op params; compared by POINTER below, not contents
    std::vector<int> swa_layers;      // indices of layers using sliding-window attention

    // std::vector<std::string> kv_names;

    // Whether a dynamically-shaped model built from `other`'s params can be
    // reused for this model: all shape-defining fields except the context
    // sizes must match.
    // NOTE(review): rope_params is compared by pointer identity — this assumes
    // the params buffer address is stable for a given model; confirm that two
    // equal-but-distinct buffers should indeed compare unequal here.
    bool can_reuse_dynamically(const ModelParams & other) const {
        return n_seq == other.n_seq && n_heads == other.n_heads && n_heads_kv == other.n_heads_kv &&
               head_size == other.head_size && rope_params == other.rope_params && swa_layers == other.swa_layers;
    }

    // Whether a statically-shaped model can be reused: in addition to the
    // dynamic criteria, the per-sequence context sizes must match exactly.
    bool can_reuse_statically(const ModelParams & other) const {
        return can_reuse_dynamically(other) && ctx_per_seq == other.ctx_per_seq &&
               ctx_per_seq_swa == other.ctx_per_seq_swa;
    }
};

// Parameters that may change on every inference/cgraph computation (contrast
// with ModelParams, which are fixed per model). All fields default to -1,
// meaning "unset". Plain aggregate; populated by the decoder
// (presumably in compute_llm_params — confirm against ggml-decoder.cpp).
struct ComputeParams {
    int n_seq_active = -1;      // number of currently active sequences
    int seq_active_start = -1;  // index of the first active sequence
    int attention_size = -1;    // attention span for regular layers
    int attention_size_swa = -1;  // attention span for sliding-window layers
    int input_len = -1;         // number of input tokens in this computation
    int token_len_per_seq = -1;  // tokens per sequence in this batch
    int past_kv_len = -1;       // length of the already-populated KV cache
    int output_len = -1;        // number of output tokens
};

class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
public:
struct NodeInfo {
ggml_tensor * node;
std::string node_name;
std::string node_op_type;
std::map<std::string, ggml_tensor *> node_inputs;
std::vector<std::string> node_inputs_names;
std::map<std::string, ggml_tensor *> node_outputs;
std::vector<std::string> node_outputs_names;
ggml_tensor * node_output;
std::string node_output_name;
int node_op_case = 0;
std::string node_op_type;
std::string node_name;
void * data_addr;
};
// Graph decoder
GgmlOvDecoder(ggml_cgraph * cgraph,
ModelParams & model_params,
ComputeParams & compute_params,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static);
bool is_static,
bool is_prefill = false,
int prefill_chunk_size = 256);

// Naive graph decoder
GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights);
Expand Down Expand Up @@ -66,7 +107,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

virtual ov::PartialShape get_output_shape(const std::string & name) const override;

virtual ov::PartialShape get_output_shape(int node_idx, const std::string & name) const override;
virtual ov::PartialShape get_output_shape(int node_idx) const override;

virtual std::vector<size_t> get_output_stride(const std::string & name) const override;

Expand All @@ -78,7 +119,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

virtual int32_t * get_output_op_params(const std::string & name) const override;

virtual int32_t * get_output_op_params(int node_idx, const std::string & name) const override;
virtual int32_t * get_output_op_params(int node_idx) const override;

virtual std::vector<std::string> get_output_names() const override;

Expand Down Expand Up @@ -116,29 +157,39 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
return m_model_weights;
}

virtual const std::vector<std::string> & get_model_output_names() const override { return m_model_output_names; }
virtual std::vector<std::string> get_model_output_names() const override {
std::vector<std::string> output_names;
output_names.reserve(m_model_outputs.size());
for (const auto & [name, tensor] : m_model_outputs) {
output_names.push_back(name);
}
return output_names;
}

virtual int get_ctx_size() const { return m_ctx; }
const std::map<std::string, ggml_tensor *> & get_model_outputs() const { return m_model_outputs; }

virtual int get_ctx_swa_size() const { return m_ctx_swa; }
virtual int get_ctx_size() const { return m_model_params.ctx; }

virtual int get_ctx_per_seq() const { return m_ctx_per_seq; }
virtual int get_ctx_swa_size() const { return m_model_params.ctx_swa; }

virtual int get_ctx_per_seq_swa() const { return m_ctx_per_seq_swa; }
virtual int get_ctx_per_seq() const { return m_model_params.ctx_per_seq; }

virtual int get_n_seq() const { return m_n_seq; }
virtual int get_ctx_per_seq_swa() const { return m_model_params.ctx_per_seq_swa; }

virtual int get_n_seq() const { return m_model_params.n_seq; }

virtual int is_swa_layer(int layer) const override {
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
return std::find(m_model_params.swa_layers.begin(), m_model_params.swa_layers.end(), layer) !=
m_model_params.swa_layers.end();
}

int get_past_kv_len() const { return m_past_kv_len; }
int get_past_kv_len() const { return m_compute_params.past_kv_len; }

int get_input_len() const { return m_input_len; }
int get_input_len() const { return m_compute_params.input_len; }

virtual int32_t * get_rope_params() const override { return m_rope_params; }
virtual int32_t * get_rope_params() const override { return m_model_params.rope_params; }

virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
// virtual std::map<std::string, std::string> get_kv_param_res_names() const override;

virtual bool is_static() const override { return m_is_static; }

Expand All @@ -159,19 +210,31 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {

void clear_model_weights() { m_model_weights.clear(); }

private:
void set_input_output(ggml_tensor * node, bool naive = false);
void add_extra_inputs();
static std::pair<ModelParams, ComputeParams> compute_llm_params(ggml_cgraph * cgraph, bool is_static);

ModelParams get_model_params() const { return m_model_params; }

ComputeParams get_compute_params() const { return m_compute_params; }

void set_model_params(const ModelParams & model_params) { m_model_params = model_params; }

void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; }

bool m_is_static = false;
bool m_is_prefill = false;
int m_prefill_chunk_size = 0;

static std::vector<size_t> get_shape(const ggml_tensor * tensor);
static std::vector<size_t> get_stride(const ggml_tensor * tensor);
static ov::element::Type get_ov_type(const ggml_tensor * tensor);
int compute_op_case(const ggml_tensor * node);
std::string compute_op_type(const ggml_tensor * node);
static std::string compute_op_type(const ggml_tensor * node);

void set_llm_params();
void validate_cgraph() const;
private:
void set_input_output(ggml_tensor * node, bool naive = false);
void add_extra_inputs();
int compute_op_case(const ggml_tensor * node) const;

bool m_is_static = false;
void validate_cgraph() const;

ggml_cgraph * m_cgraph = nullptr;
std::vector<ggml_tensor *> m_nodes;
Expand All @@ -184,30 +247,11 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
std::map<std::string, ggml_tensor *> m_model_outputs;
std::vector<NodeInfo> m_node_info_list;

// Fixed for a model
int m_ctx = -1;
int m_ctx_swa = -1;
int m_ctx_per_seq = -1;
int m_ctx_per_seq_swa = -1;
int m_n_seq = -1;
int m_n_heads = -1;
int m_n_heads_kv = -1;
int m_head_size = -1;
std::vector<int> m_swa_layers;
std::vector<std::string> m_kv_names;

// Changed per inference
int m_n_seq_active = -1;
int m_seq_active_start = -1;
int m_attention_size = -1;
int m_attention_size_swa = -1;
int m_input_len = -1;
int m_token_len_per_seq = -1;
int m_past_kv_len = -1;
int32_t * m_rope_params = nullptr;
ModelParams m_model_params;
ComputeParams m_compute_params;
};

void print_tensor_address_map(const ggml_cgraph * cgraph);
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-openvino/openvino/decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GgmlDecoder : public DecoderBase {

virtual PartialShape get_output_shape(const std::string& name) const = 0;

virtual PartialShape get_output_shape(int node_idx, const std::string& name) const = 0;
virtual PartialShape get_output_shape(int node_idx) const = 0;

virtual std::vector<size_t> get_output_stride(const std::string& name) const = 0;

Expand All @@ -51,7 +51,7 @@ class GgmlDecoder : public DecoderBase {

virtual int32_t* get_output_op_params(const std::string& name) const = 0;

virtual int32_t* get_output_op_params(int node_idx, const std::string& name) const = 0;
virtual int32_t * get_output_op_params(int node_idx) const = 0;

virtual std::vector<std::string> get_output_names() const = 0;

Expand All @@ -72,10 +72,10 @@ class GgmlDecoder : public DecoderBase {
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_inputs() const = 0;
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_extra_inputs() const = 0;
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_weights() const = 0;
virtual const std::vector<std::string>& get_model_output_names() const = 0;
virtual std::vector<std::string> get_model_output_names() const = 0;

virtual int32_t* get_rope_params() const = 0;
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
// virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;

virtual bool is_static() const = 0;

Expand Down
8 changes: 2 additions & 6 deletions ggml/src/ggml-openvino/openvino/node_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,13 @@ class NodeContext : public frontend::NodeContext {

std::string get_output_name() const { return m_output_names[0]; }

PartialShape get_output_shape(size_t index) const {
return m_decoder->get_output_shape(m_node_idx, m_output_names[index]);
}
PartialShape get_output_shape() const { return m_decoder->get_output_shape(m_node_idx); }

int32_t* get_input_op_params(size_t index) const {
return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]);
}

int32_t* get_output_op_params(size_t index) const {
return m_decoder->get_output_op_params(m_node_idx, m_output_names[index]);
}
int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); }

ov::element::Type get_output_type(size_t index) const {
return m_decoder->get_output_type(m_output_names[index]);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/op/cont.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ OutputVector translate_cont(const NodeContext & context) {
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case");

auto src_shape = context.get_input_shape(0).to_shape();
auto dst_shape = context.get_output_shape(0).to_shape();
auto dst_shape = context.get_output_shape().to_shape();
ov::Output<Node> res;

if (op_case == 1) {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
auto v = context.get_input(2);
auto mask = context.get_input(3);

float * params = reinterpret_cast<float *>(context.get_output_op_params(0));
float * params = reinterpret_cast<float *>(context.get_output_op_params());
float scale = params[0];
// float max_bias = params[1];
// float logit_softcap = params[2];
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
src1 = split->output(1);
}

int32_t * params = context.get_output_op_params(0);
int32_t * params = context.get_output_op_params();
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
src1 = split->output(1);
}

int32_t * params = context.get_output_op_params(0);
int32_t * params = context.get_output_op_params();
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);
Expand Down
29 changes: 22 additions & 7 deletions ggml/src/ggml-openvino/openvino/op/permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ OutputVector translate_permute(const NodeContext & context) {
if (op_case == 1) {
res = std::make_shared<ov::op::v1::Transpose>(src, perm);
} else if (op_case == 4) {
auto output_shape = context.get_output_shape(0).to_shape();
auto output_shape = context.get_output_shape().to_shape();
auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]});
auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
auto n_seq_active = context.get_input("n_seq_active");
auto n_seq_active = context.has_input("n_seq_active") ?
context.get_input("n_seq_active") :
ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[0]});
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});

auto new_shape =
Expand All @@ -49,26 +51,39 @@ OutputVector translate_permute(const NodeContext & context) {
res = std::make_shared<ov::op::v1::Transpose>(reshaped, perm);
} else {
auto cache_shape = src.get_partial_shape();
auto output_shape = context.get_output_shape(0).to_shape();
auto output_shape = context.get_output_shape().to_shape();
int64_t head_size = output_shape[3];
int64_t n_heads = output_shape[1];
int64_t ctx_per_seq = cache_shape[2].is_static() ? cache_shape[2].get_length() : -1;
int64_t n_seq = cache_shape[1].get_length();

Output<Node> attention_size;
if (op_case == 2) {
if (!context.has_input("attention_size")) {
attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]});
} else if (op_case == 2) {
attention_size = context.get_input("attention_size");
} else {
attention_size = context.get_input("attention_size_swa");
}

Output<Node> seq_active_start;
Output<Node> seq_active_end;
if (context.has_input("seq_active_start")) {
seq_active_start = context.get_input("seq_active_start");
seq_active_end = context.get_input("seq_active_end");
} else {
int64_t n_seq_active = output_shape[0];
size_t offset = *((size_t *) context.get_input_op_params(0));
int64_t seq_active_start_val = offset / context.get_input_stride(0)[0];
int64_t seq_active_end_val = seq_active_start_val + n_seq_active;
seq_active_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_start_val});
seq_active_end = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_end_val});
}

// 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size]
// 2. slice out the active sequences
// 3. slice out the attention part in each sequence
// 4. permute
auto seq_active_start = context.get_input("seq_active_start");
auto seq_active_end = context.get_input("seq_active_end");

auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});

Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-openvino/openvino/op/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace op {

OutputVector translate_reshape(const NodeContext & context) {
num_inputs_check(context, 1, 1);
if (context.get_input_shape(0) == context.get_output_shape(0)) {
if (context.get_input_shape(0) == context.get_output_shape()) {
return {context.get_input(0)};
}

Expand All @@ -29,7 +29,7 @@ OutputVector translate_reshape(const NodeContext & context) {
op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6,
"Unsupported RESHAPE case");

auto output_shape = context.get_output_shape(0).to_shape();
auto output_shape = context.get_output_shape().to_shape();
std::shared_ptr<ov::Node> new_shape_node;
if (op_case == 1) {
new_shape_node = ov::op::v0::Constant::create(
Expand All @@ -50,18 +50,18 @@ OutputVector translate_reshape(const NodeContext & context) {
return {context.get_input(0).get_node_shared_ptr()->input_value(0)};

} else if (op_case == 5) {
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape(0).to_shape()[3]};
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);

// // Alternative
// auto token_len = context.get_input("token_len");
// auto emb_size =
// ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape(0).to_shape()[3]});
// ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape().to_shape()[3]});
// auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
// new_shape_node = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, one, token_len, emb_size}, 0);

} else if (op_case == 6) {
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape(0).to_shape());
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape().to_shape());
}
auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false);
return rename_outputs_with_suffix({res}, context.get_name());
Expand Down
Loading