Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions src/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,43 @@ float sigmoid(float x) {
return 1.0 / (1.0 + std::exp(-x));
}

// Reports whether `span` describes a well-formed range inside a text of
// length `textLength`. Despite the name, nothing is adjusted: the span is
// only inspected and never modified.
//
// The span is accepted only when all of the following hold:
//   - neither startIdx nor endIdx is negative,
//   - startIdx is strictly less than textLength,
//   - endIdx is at most textLength,
//   - startIdx does not exceed endIdx.
//
// As a consequence, every span is rejected when textLength is 0.
static bool adjustSpanToTextBounds(const Span& span, size_t textLength) {
    // Reject negative indices first so the unsigned conversions below
    // cannot wrap around to huge positive values.
    const bool nonNegative = span.startIdx >= 0 && span.endIdx >= 0;
    if (!nonNegative) {
        return false;
    }

    const size_t first = static_cast<size_t>(span.startIdx);
    const size_t last = static_cast<size_t>(span.endIdx);

    // Start must point at an existing position; end may sit one past the
    // last position.
    const bool insideText = first < textLength && last <= textLength;

    // The range must not be inverted.
    return insideText && span.startIdx <= span.endIdx;
}

// Returns the part of `text` covered by `span`, i.e. the characters in the
// half-open range [span.startIdx, span.endIdx).
//
// Returns an empty string when adjustSpanToTextBounds() rejects the span for
// this text (negative index, index past the end of the text, or start > end).
// Note that a valid zero-length span (startIdx == endIdx) also yields an
// empty string, so the return value alone does not distinguish "invalid span"
// from "empty span".
static std::string safeCopySpanText(const Span& span, const std::string& text) {
    if (!adjustSpanToTextBounds(span, text.length())) {
        return "";
    }

    // The validation above guarantees 0 <= startIdx <= endIdx <= text.length(),
    // so the casts cannot wrap and substr() is called with an in-range
    // position; no second bounds check is needed.
    const size_t startPos = static_cast<size_t>(span.startIdx);
    const size_t endPos = static_cast<size_t>(span.endIdx);

    return text.substr(startPos, endPos - startPos);
}

// Returns true when one of the two spans lies entirely within the other.
// Both boundaries are compared inclusively, so two spans with identical
// start and end indices count as nested.
bool Decoder::isNested(const Span& s1, const Span& s2) {
    // s2 is contained in s1.
    if (s1.startIdx <= s2.startIdx && s2.endIdx <= s1.endIdx) {
        return true;
    }

    // Otherwise the spans are nested only if s1 is contained in s2.
    return s2.startIdx <= s1.startIdx && s1.endIdx <= s2.endIdx;
}
Expand Down Expand Up @@ -109,11 +146,16 @@ std::vector<std::vector<Span>> SpanDecoder::decode(
Span span;
span.startIdx = tokens[batch_id][startToken].start;
span.endIdx = tokens[batch_id][endToken].end;
span.text = texts[batch_id].substr(span.startIdx, span.endIdx - span.startIdx);
span.classLabel = entities[entity];
span.prob = prob;

spans[batch_id].push_back(span);
// Safely extract span text with bounds checking
span.text = safeCopySpanText(span, texts[batch_id]);

// Skip spans with invalid indices that couldn't extract text
if (!span.text.empty()) {
spans[batch_id].push_back(span);
}
}
}

Expand Down Expand Up @@ -169,13 +211,18 @@ std::vector<std::vector<Span>> TokenDecoder::decode(
Span span;
span.startIdx = tokens[batch_id][startToken].start;
span.endIdx = tokens[batch_id][endToken].end;
span.text = texts[batch_id].substr(span.startIdx, span.endIdx - span.startIdx);
span.classLabel = entities[entity];
span.prob = score_sum / n;

spans[batch_id].push_back(span);
// Safely extract span text with bounds checking
span.text = safeCopySpanText(span, texts[batch_id]);

// Skip spans with invalid indices that couldn't extract text
if (!span.text.empty()) {
spans[batch_id].push_back(span);
}
}
}

return batchGreedySearch(spans, flatNer, multiLabel);
}
}