Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 79 additions & 15 deletions elixir/threadr/lib/threadr/ml/constrained_qa.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ defmodule Threadr.ML.ConstrainedQA do
ConversationQAIntent,
Generation,
GenerationProviderOpts,
HybridRetriever,
ReconstructionQuery,
SemanticQA
}
Expand All @@ -23,6 +24,7 @@ defmodule Threadr.ML.ConstrainedQA do
@pair_cluster_gap_seconds 600
@pair_cluster_scan_limit 96
@question_stopwords MapSet.new([
"all",
"a",
"an",
"are",
Expand All @@ -39,7 +41,9 @@ defmodule Threadr.ML.ConstrainedQA do
"it",
"me",
"should",
"talk",
"tell",
"the",
"to",
"want",
"what",
Expand Down Expand Up @@ -147,6 +151,7 @@ defmodule Threadr.ML.ConstrainedQA do
actors: [actor],
counterpart_actors: [target_actor],
literal_terms: [],
topic_terms: [],
literal_match: "all",
time_scope: infer_time_scope(question),
scope_current_channel: Keyword.get(opts, :requester_channel_name) != nil,
Expand Down Expand Up @@ -205,6 +210,7 @@ defmodule Threadr.ML.ConstrainedQA do
actors: normalize_refs(Map.get(payload, "actors", [])),
counterpart_actors: normalize_refs(Map.get(payload, "counterpart_actors", [])),
literal_terms: normalize_literal_terms(Map.get(payload, "literal_terms", [])),
topic_terms: [],
literal_match: normalize_literal_match(Map.get(payload, "literal_match")),
time_scope: normalize_time_scope(Map.get(payload, "time_scope")),
scope_current_channel: Map.get(payload, "scope_current_channel") == true,
Expand All @@ -229,7 +235,8 @@ defmodule Threadr.ML.ConstrainedQA do
%{
actors: actors,
counterpart_actors: [],
literal_terms: literal_terms,
literal_terms: [],
topic_terms: literal_terms,
literal_match: "all",
time_scope: infer_time_scope(question),
scope_current_channel: Keyword.get(opts, :requester_channel_name) != nil,
Expand Down Expand Up @@ -282,24 +289,52 @@ defmodule Threadr.ML.ConstrainedQA do
end

# Picks the retrieval strategy for a constrained question and tags the result
# with the name of the strategy that produced it.
#
# Strategy order:
#   1. literal terms present               -> literal-term search
#                                             ("literal_term_messages")
#   2. actors and topic terms both present -> hybrid search
#                                             ("hybrid_topic_messages"), falling
#                                             back to the direct filter
#                                             ("filtered_messages") when the
#                                             hybrid search yields nothing
#   3. anything else                       -> direct filter ("filtered_messages")
#
# Returns `{:ok, matches, metadata}`, where `metadata` is built by
# `query_metadata/2`, or `{:error, :no_constrained_matches}` when the selected
# strategy (including its fallback) produced an empty list.
defp retrieve_matches(tenant_schema, constraints, opts) do
  cond do
    constraints.literal_terms != [] ->
      matches = fetch_literal_matches(tenant_schema, constraints, opts)

      if matches == [] do
        {:error, :no_constrained_matches}
      else
        {:ok, matches, query_metadata(constraints, "literal_term_messages")}
      end

    # Map.get/3 with a default tolerates constraint maps that carry no
    # :topic_terms key at all.
    constraints.actors != [] and Map.get(constraints, :topic_terms, []) != [] ->
      case fetch_hybrid_topic_matches(tenant_schema, constraints, opts) do
        [] ->
          # Hybrid retrieval found nothing (or failed); retry with the plain
          # constraint filter before giving up.
          matches = fetch_direct_matches(tenant_schema, constraints, opts)

          if matches == [] do
            {:error, :no_constrained_matches}
          else
            {:ok, matches, query_metadata(constraints, "filtered_messages")}
          end

        matches ->
          {:ok, matches, query_metadata(constraints, "hybrid_topic_messages")}
      end

    true ->
      matches = fetch_direct_matches(tenant_schema, constraints, opts)

      if matches == [] do
        {:error, :no_constrained_matches}
      else
        {:ok, matches, query_metadata(constraints, "filtered_messages")}
      end
  end
end

# Runs the hybrid retriever for the constraint's topic query, restricted to the
# constrained actors.
#
# Returns the list of matches. A retriever error is deliberately collapsed to
# `[]`, so "search failed" and "search found nothing" look the same to the
# caller.
defp fetch_hybrid_topic_matches(tenant_schema, constraints, opts) do
  actor_ids = Enum.map(constraints.actors, &dump_uuid!(&1.id))

  case HybridRetriever.search_messages(
         tenant_schema,
         hybrid_topic_query(constraints),
         hybrid_topic_opts(constraints, opts, actor_ids)
       ) do
    {:ok, matches, _query} -> matches
    {:error, _reason} -> []
  end
end

Expand Down Expand Up @@ -756,6 +791,7 @@ defmodule Threadr.ML.ConstrainedQA do
actor_handles: Enum.map(constraints.actors, & &1.handle),
counterpart_actor_handles: Enum.map(constraints.counterpart_actors, & &1.handle),
literal_terms: constraints.literal_terms,
topic_terms: Map.get(constraints, :topic_terms, []),
literal_match: constraints.literal_match,
time_scope: constraints.time_scope,
channel_name: constraints.requester_channel_name
Expand Down Expand Up @@ -803,6 +839,7 @@ defmodule Threadr.ML.ConstrainedQA do
Enum.map(constraints.counterpart_actors, & &1.handle)
)
|> maybe_put_summary("Literal terms", constraints.literal_terms)
|> maybe_put_summary("Topic terms", Map.get(constraints, :topic_terms, []))
|> maybe_put_summary(
"Time scope",
constraints.time_scope != :none && Atom.to_string(constraints.time_scope)
Expand Down Expand Up @@ -1079,5 +1116,32 @@ defmodule Threadr.ML.ConstrainedQA do
# Passes binaries through `Ecto.UUID.dump!/1`; every other value is returned
# untouched.
defp dump_uuid!(value) do
  case value do
    binary when is_binary(binary) -> Ecto.UUID.dump!(binary)
    other -> other
  end
end

# Builds the free-text query handed to the hybrid retriever.
#
# Uses the constraint's topic terms joined by single spaces. When there are no
# topic terms (or the key is absent), falls back to the actor handles joined
# the same way. Returns "" when both are empty.
defp hybrid_topic_query(constraints) do
  case Map.get(constraints, :topic_terms, []) do
    # map_join/3 avoids building the intermediate list of handles.
    [] -> Enum.map_join(constraints.actors, " ", & &1.handle)
    terms -> Enum.join(terms, " ")
  end
end

# Assembles the option list for the hybrid topic search: the caller's opts with
# :actor_ids and :channel_name overwritten, time bounds added according to the
# constraint's time scope, and :limit defaulted to @default_limit only when the
# caller did not supply one.
defp hybrid_topic_opts(constraints, opts, actor_ids) do
  with_actors = Keyword.put(opts, :actor_ids, actor_ids)
  with_channel = Keyword.put(with_actors, :channel_name, constraints.requester_channel_name)
  with_time = apply_hybrid_time_scope(with_channel, constraints.time_scope)

  Keyword.put_new(with_time, :limit, @default_limit)
end

# Adds :since/:until to the opts for the :today and :yesterday scopes, using
# `today_bounds/0` and `yesterday_bounds/0` respectively; :none leaves the opts
# as they are. Any other scope matches no clause and raises
# FunctionClauseError.
defp apply_hybrid_time_scope(opts, :none), do: opts

defp apply_hybrid_time_scope(opts, scope) when scope in [:today, :yesterday] do
  {lower, upper} = if scope == :today, do: today_bounds(), else: yesterday_bounds()

  with_since = Keyword.put(opts, :since, lower)
  Keyword.put(with_since, :until, upper)
end

# True only for `nil` and the empty string; whitespace-only strings and every
# other term count as non-blank.
defp blank?(nil), do: true
defp blank?(""), do: true
defp blank?(_value), do: false
end
56 changes: 52 additions & 4 deletions elixir/threadr/lib/threadr/ml/conversation_summary_qa.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Threadr.ML.ConversationSummaryQA do
ConversationSummaryQAIntent,
Generation,
GenerationProviderOpts,
HybridRetriever,
ReconstructionQuery
}

Expand All @@ -31,7 +32,7 @@ defmodule Threadr.ML.ConversationSummaryQA do
resolved_opts <- resolve_summary_opts(intent, opts),
:ok <- require_time_bounds(resolved_opts),
conversations <- fetch_conversations(tenant.schema_name, resolved_opts),
messages <- fetch_summary_messages(tenant.schema_name, resolved_opts),
messages <- fetch_summary_messages(tenant.schema_name, question, resolved_opts),
true <- conversations != [] or messages != [] do
build_answer(tenant, question, intent, conversations, messages, resolved_opts)
else
Expand Down Expand Up @@ -73,7 +74,7 @@ defmodule Threadr.ML.ConversationSummaryQA do
facts_over_time: [],
context: context,
answer: answer
}}
}}
end
end

Expand Down Expand Up @@ -191,7 +192,10 @@ defmodule Threadr.ML.ConversationSummaryQA do
rows -> Threadr.ML.SemanticQA.build_context(rows)
end

Enum.join(header_lines ++ conversation_lines ++ message_window_lines ++ blankable(evidence), "\n\n")
Enum.join(
header_lines ++ conversation_lines ++ message_window_lines ++ blankable(evidence),
"\n\n"
)
end

defp citation_matches(citations) do
Expand Down Expand Up @@ -249,7 +253,24 @@ defmodule Threadr.ML.ConversationSummaryQA do
|> maybe_filter_conversation_until(Keyword.get(opts, :until))
end

# Collects the messages used as evidence for a conversation summary.
#
# Hybrid-retrieval hits for `question` are taken first. Window messages that
# are not already among those hits then fill whatever room is left under
# `message_limit(opts)`. The combined list is de-duplicated by :message_id and
# returned in ascending `{observed_at, message_id}` order.
defp fetch_summary_messages(tenant_schema, question, opts) do
  limit = message_limit(opts)
  window_messages = fetch_window_messages(tenant_schema, opts)
  hybrid_messages = fetch_hybrid_summary_messages(tenant_schema, question, opts)

  hybrid_ids = MapSet.new(Enum.map(hybrid_messages, & &1.message_id))

  window_fill =
    window_messages
    |> Enum.reject(&MapSet.member?(hybrid_ids, &1.message_id))
    # max/2 guards against a negative count should the hybrid hits alone
    # already reach the limit.
    |> Enum.take(max(limit - length(hybrid_messages), 0))

  (hybrid_messages ++ window_fill)
  |> Enum.uniq_by(& &1.message_id)
  |> Enum.sort_by(&{&1.observed_at, &1.message_id}, :asc)
end

defp fetch_window_messages(tenant_schema, opts) do
limit = message_limit(opts)

from(m in "messages",
Expand All @@ -276,6 +297,17 @@ defmodule Threadr.ML.ConversationSummaryQA do
|> Enum.map(&normalize_message/1)
end

# Asks the hybrid retriever for messages relevant to `question`, capped at
# `hybrid_message_limit(opts)`. A retriever error yields `[]`.
defp fetch_hybrid_summary_messages(tenant_schema, question, opts) do
  cap = hybrid_message_limit(opts)
  search_opts = hybrid_summary_opts(opts, cap)

  case HybridRetriever.search_messages(tenant_schema, question, search_opts) do
    {:error, _reason} -> []
    {:ok, hits, _query} -> Enum.take(hits, cap)
  end
end

defp resolve_summary_opts(intent, opts) do
opts
|> maybe_put_inferred_bounds(intent.time_scope)
Expand Down Expand Up @@ -443,6 +475,22 @@ defmodule Threadr.ML.ConversationSummaryQA do
|> min(@max_message_limit)
end

# Share of the message budget given to hybrid retrieval: one third of
# `message_limit(opts)`, clamped to the range 8..40, and never more than the
# overall limit itself.
defp hybrid_message_limit(opts) do
  ceiling = message_limit(opts)
  third = div(ceiling, 3)
  clamped = min(max(third, 8), 40)

  min(clamped, ceiling)
end

# Derives the hybrid-search options from the summary opts: :limit is replaced
# by `limit`, and :channel_name is set to the caller's :requester_channel_name
# (nil when that key is absent).
defp hybrid_summary_opts(opts, limit) do
  channel = Keyword.get(opts, :requester_channel_name)
  limited = Keyword.put(opts, :limit, limit)

  Keyword.put(limited, :channel_name, channel)
end

# Names the retrieval mode from which evidence lists are non-empty. An empty
# conversation list always means "message_window", even when the message list
# is empty too.
defp retrieval_mode(conversations, messages) do
  case {conversations, messages} do
    {[], _} -> "message_window"
    {_, []} -> "reconstructed_conversations"
    {_, _} -> "reconstructed_conversations_plus_messages"
  end
end
Expand Down
Loading
Loading