Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 72 additions & 10 deletions elixir/threadr/lib/threadr/ml/constrained_qa.ex
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ defmodule Threadr.ML.ConstrainedQA do
alias Threadr.Repo

@default_limit 8
@direct_match_scan_limit 96
@today_actor_result_limit 20
@today_actor_scan_limit 384
@pair_message_window_seconds 300
@pair_cluster_gap_seconds 600
@pair_cluster_scan_limit 96
Expand Down Expand Up @@ -147,6 +150,7 @@ defmodule Threadr.ML.ConstrainedQA do
actors: [actor],
counterpart_actors: [target_actor],
literal_terms: [],
topic_terms: [],
literal_match: "all",
time_scope: infer_time_scope(question),
scope_current_channel: Keyword.get(opts, :requester_channel_name) != nil,
Expand Down Expand Up @@ -205,6 +209,7 @@ defmodule Threadr.ML.ConstrainedQA do
actors: normalize_refs(Map.get(payload, "actors", [])),
counterpart_actors: normalize_refs(Map.get(payload, "counterpart_actors", [])),
literal_terms: normalize_literal_terms(Map.get(payload, "literal_terms", [])),
topic_terms: [],
literal_match: normalize_literal_match(Map.get(payload, "literal_match")),
time_scope: normalize_time_scope(Map.get(payload, "time_scope")),
scope_current_channel: Map.get(payload, "scope_current_channel") == true,
Expand All @@ -223,17 +228,18 @@ defmodule Threadr.ML.ConstrainedQA do
if actors == [] do
{:error, :fallback}
else
literal_terms = heuristic_literal_terms(question, actors)
topic_terms = heuristic_topic_terms(question, actors)

{:ok,
%{
actors: actors,
counterpart_actors: [],
literal_terms: literal_terms,
literal_terms: [],
topic_terms: topic_terms,
literal_match: "all",
time_scope: infer_time_scope(question),
scope_current_channel: Keyword.get(opts, :requester_channel_name) != nil,
focus: heuristic_focus(literal_terms),
focus: heuristic_focus(topic_terms),
requester_channel_name: Keyword.get(opts, :requester_channel_name),
pair_required: false
}}
Expand Down Expand Up @@ -293,10 +299,10 @@ defmodule Threadr.ML.ConstrainedQA do
{:error, :no_constrained_matches}
else
retrieval =
if constraints.literal_terms != [] do
"literal_term_messages"
else
"filtered_messages"
cond do
constraints.literal_terms != [] -> "literal_term_messages"
Map.get(constraints, :topic_terms, []) != [] -> "topic_ranked_messages"
true -> "filtered_messages"
end

{:ok, matches, query_metadata(constraints, retrieval)}
Expand All @@ -320,7 +326,8 @@ defmodule Threadr.ML.ConstrainedQA do
end

defp fetch_direct_matches(tenant_schema, constraints, opts) do
limit = Keyword.get(opts, :limit, @default_limit)
limit = direct_match_limit(constraints, opts)
scan_limit = direct_match_scan_limit(constraints, limit)
actor_ids = Enum.map(constraints.actors, &dump_uuid!(&1.id))
counterpart_patterns = counterpart_patterns(constraints.counterpart_actors)

Expand All @@ -333,7 +340,7 @@ defmodule Threadr.ML.ConstrainedQA do
where: ^channel_filter(constraints.requester_channel_name),
where: ^counterpart_filter(counterpart_patterns),
order_by: [desc: m.observed_at, desc: m.id],
limit: ^limit,
limit: ^scan_limit,
select: %{
message_id: m.id,
external_id: m.external_id,
Expand All @@ -347,6 +354,7 @@ defmodule Threadr.ML.ConstrainedQA do
|> apply_time_scope(constraints.time_scope)
|> Repo.all(prefix: tenant_schema)
|> Enum.map(&normalize_match/1)
|> rank_direct_matches(constraints, limit)
end

defp fetch_literal_matches(tenant_schema, constraints, opts) do
Expand Down Expand Up @@ -756,6 +764,7 @@ defmodule Threadr.ML.ConstrainedQA do
actor_handles: Enum.map(constraints.actors, & &1.handle),
counterpart_actor_handles: Enum.map(constraints.counterpart_actors, & &1.handle),
literal_terms: constraints.literal_terms,
topic_terms: Map.get(constraints, :topic_terms, []),
literal_match: constraints.literal_match,
time_scope: constraints.time_scope,
channel_name: constraints.requester_channel_name
Expand Down Expand Up @@ -803,6 +812,7 @@ defmodule Threadr.ML.ConstrainedQA do
Enum.map(constraints.counterpart_actors, & &1.handle)
)
|> maybe_put_summary("Literal terms", constraints.literal_terms)
|> maybe_put_summary("Topic terms", Map.get(constraints, :topic_terms, []))
|> maybe_put_summary(
"Time scope",
constraints.time_scope != :none && Atom.to_string(constraints.time_scope)
Expand Down Expand Up @@ -846,7 +856,7 @@ defmodule Threadr.ML.ConstrainedQA do

defp normalize_literal_terms(_value), do: []

defp heuristic_literal_terms(question, actors) do
defp heuristic_topic_terms(question, actors) do
actor_terms =
actors
|> Enum.flat_map(fn actor -> [actor.handle, actor.display_name, actor.external_id] end)
Expand Down Expand Up @@ -874,6 +884,58 @@ defmodule Threadr.ML.ConstrainedQA do
# Chooses the "focus" label for heuristically built constraints:
# "activity" when the given term list is exactly the empty list,
# "topics" for any other value.
defp heuristic_focus(terms) do
  case terms do
    [] -> "activity"
    _anything_else -> "topics"
  end
end

# Re-orders scanned direct matches so that messages sharing the most tokens
# with the constraint's topic terms come first, then trims the result.
#
#   * matches     - match maps; each must expose :body, :observed_at and
#                   :message_id
#   * constraints - constraint map; :topic_terms is optional (defaults to [])
#   * limit       - maximum number of matches returned
#
# When there are no topic terms, or no match overlaps any of them, the
# incoming order is kept and only the first `limit` matches are returned.
defp rank_direct_matches(matches, constraints, limit) do
  topic_terms =
    constraints
    |> Map.get(:topic_terms, [])
    |> MapSet.new(&String.downcase/1)

  if MapSet.size(topic_terms) == 0 do
    Enum.take(matches, limit)
  else
    scored_matches =
      Enum.map(matches, fn match ->
        overlap =
          match.body
          |> message_tokens()
          |> MapSet.new()
          |> MapSet.intersection(topic_terms)
          |> MapSet.size()

        {match, overlap}
      end)

    if Enum.any?(scored_matches, fn {_match, overlap} -> overlap > 0 end) do
      scored_matches
      # An explicit comparator is used instead of `Enum.sort_by(..., :desc)`:
      # term ordering compares date/time structs field by field (day before
      # month and year), which is not chronological.
      |> Enum.sort(&direct_match_ranks_before?/2)
      |> Enum.map(&elem(&1, 0))
      |> Enum.take(limit)
    else
      Enum.take(matches, limit)
    end
  end
end

# Descending comparator for `{match, overlap}` pairs: higher overlap first,
# then the later :observed_at, then the greater :message_id. Returns true on
# full ties so the sort stays stable.
defp direct_match_ranks_before?({left, left_overlap}, {right, right_overlap}) do
  if left_overlap == right_overlap do
    case compare_observed_at(left.observed_at, right.observed_at) do
      :gt -> true
      :lt -> false
      :eq -> left.message_id >= right.message_id
    end
  else
    left_overlap > right_overlap
  end
end

# Chronological comparison returning :lt | :eq | :gt.
# NOTE(review): :observed_at is presumably a DateTime (the tests insert
# DateTime values) — confirm against normalize_match/1. Other shapes fall back
# to plain term ordering, which is what the previous implementation used.
defp compare_observed_at(%DateTime{} = left, %DateTime{} = right),
  do: DateTime.compare(left, right)

defp compare_observed_at(%NaiveDateTime{} = left, %NaiveDateTime{} = right),
  do: NaiveDateTime.compare(left, right)

defp compare_observed_at(left, right) when left > right, do: :gt
defp compare_observed_at(left, right) when left < right, do: :lt
defp compare_observed_at(_left, _right), do: :eq

# Resolves how many direct matches should be returned.
#
# The caller's :limit option (falling back to @default_limit) is used as-is,
# except for a "today" question about at least one actor with no counterpart
# actors, where the limit is raised to at least @today_actor_result_limit.
defp direct_match_limit(constraints, opts) do
  requested = Keyword.get(opts, :limit, @default_limit)

  case constraints do
    %{actors: [_ | _], counterpart_actors: [], time_scope: :today} ->
      max(requested, @today_actor_result_limit)

    _other ->
      requested
  end
end

# Resolves the scan limit derived from the result `limit`.
#
# The scan limit is twelve times the result limit, but never below a floor:
# @today_actor_scan_limit for a "today" question about at least one actor with
# no counterpart actors, @direct_match_scan_limit for everything else.
defp direct_match_scan_limit(constraints, limit) do
  minimum_scan =
    case constraints do
      %{actors: [_ | _], counterpart_actors: [], time_scope: :today} ->
        @today_actor_scan_limit

      _other ->
        @direct_match_scan_limit
    end

  max(limit * 12, minimum_scan)
end

# Maps a raw time-scope value onto the atoms used in the constraint map.
# Only the exact strings "today" and "yesterday" are recognised; any other
# value (including nil) yields :none.
defp normalize_time_scope(value) do
  case value do
    "today" -> :today
    "yesterday" -> :yesterday
    _unrecognised -> :none
  end
end
Expand Down
40 changes: 37 additions & 3 deletions elixir/threadr/test/threadr/ml/constrained_qa_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ defmodule Threadr.ML.ConstrainedQATest do
thanew = create_actor!(tenant.schema_name, "THANEW")
leku = create_actor!(tenant.schema_name, "leku")
channel = create_channel!(tenant.schema_name, "#!chases")
now = DateTime.utc_now() |> DateTime.truncate(:second)
now = DateTime.new!(Date.utc_today(), ~T[12:00:00], "Etc/UTC")

create_message!(
tenant.schema_name,
Expand Down Expand Up @@ -77,13 +77,47 @@ defmodule Threadr.ML.ConstrainedQATest do
)

assert result.query.mode == "constrained_qa"
assert result.query.retrieval == "literal_term_messages"
assert result.query.retrieval == "topic_ranked_messages"
assert result.query.actor_handles == ["THANEW"]
assert result.query.literal_terms == ["dnb"]
assert result.query.topic_terms == ["dnb"]
assert result.context =~ "not a big fan of dnb tbh"
refute result.context =~ "i like jungle more than dnb"
end

test "falls back to a wider same-day actor slice when topical wording is just rhetorical fluff" do
  tenant = create_tenant!("Constrained QA Actor Dirt")
  thanew = create_actor!(tenant.schema_name, "THANEW")
  channel = create_channel!(tenant.schema_name, "#!chases")

  # Anchor at noon UTC of the current day rather than the wall clock: the
  # messages below span three hours (12 * 900s), so a wall-clock anchor would
  # push some of them into tomorrow whenever the suite runs after 21:00 UTC,
  # and the "today" question would then miss them.
  now = DateTime.new!(Date.utc_today(), ~T[12:00:00], "Etc/UTC")

  # Twelve messages exceeds the default result limit, so this only passes when
  # the wider same-day actor limit is applied.
  for index <- 1..12 do
    create_message!(
      tenant.schema_name,
      thanew.id,
      channel.id,
      "THANEW topic #{index} from earlier today",
      "thanew-dirt-#{index}",
      DateTime.add(now, index * 900, :second)
    )
  end

  assert {:ok, result} =
           ConstrainedQA.answer_question(
             tenant.subject_name,
             "what disgusting filth did THANEW talk about today? i want all the dirt",
             requester_channel_name: "#!chases",
             generation_provider: Threadr.TestConstraintGenerationProvider,
             generation_model: "test-chat"
           )

  assert result.query.mode == "constrained_qa"
  assert result.query.retrieval == "topic_ranked_messages"
  assert result.query.actor_handles == ["THANEW"]
  assert length(result.citations) == 12

  # The trailing " from" keeps "topic 1" from being satisfied by topics 10-12,
  # so this really checks that the oldest message made it into the context.
  assert result.context =~ "THANEW topic 1 from"
  assert result.context =~ "THANEW topic 12 from"
end

test "answers current-channel topical summary questions constrained to today" do
tenant = create_tenant!("Constrained QA Channel Today")
farmr = create_actor!(tenant.schema_name, "farmr")
Expand Down
Loading