Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package ai.reveng.toolkit.ghidra.binarysimilarity.ui.aidecompiler;

import ai.reveng.invoker.ApiException;
import ai.reveng.model.GetAiDecompilationTask;
import ai.reveng.toolkit.ghidra.core.services.api.GhidraRevengService;
import ai.reveng.toolkit.ghidra.core.services.api.TypedApiInterface;
import ai.reveng.toolkit.ghidra.core.services.logging.ReaiLoggingService;
import ai.reveng.toolkit.ghidra.plugins.ReaiPluginPackage;
import ai.reveng.toolkit.ghidra.core.services.api.types.AIDecompilationStatus;
import docking.ActionContext;
import docking.action.DockingAction;
import docking.action.ToolBarData;
Expand All @@ -32,11 +32,15 @@ public class AIDecompilationdWindow extends ComponentProviderAdapter {

private RSyntaxTextArea textArea;
private RTextScrollPane sp;
private JTextArea descriptionArea;
private JEditorPane descriptionArea;
private JLabel predictedNameLabel;
private JButton usePredictedNameButton;
private JPanel predictedNamePanel;
private JComponent component;
private Function function;
private TaskMonitorComponent taskMonitorComponent;
private final Map<Function, AIDecompilationStatus> cache = new java.util.HashMap<>();
private final Map<Function, GetAiDecompilationTask> cache = new java.util.HashMap<>();


public AIDecompilationdWindow(PluginTool tool, String owner) {
super(tool, ReaiPluginPackage.WINDOW_PREFIX + "AI Decompilation", owner);
Expand Down Expand Up @@ -125,14 +129,33 @@ private JComponent buildComponent() {

component = new JPanel(new BorderLayout());

// Create header panel to hold description and predicted name panel
JPanel headerPanel = new JPanel();
headerPanel.setLayout(new BoxLayout(headerPanel, BoxLayout.Y_AXIS));

descriptionArea = new JTextArea(10, 60);
descriptionArea.setLineWrap(true);
descriptionArea.setText("No function selected or binary not analysed yet with RevEng.AI");
// Description area
descriptionArea = new JEditorPane();
descriptionArea.setContentType("text/html");
descriptionArea.setEditable(false);
component.add(descriptionArea, BorderLayout.NORTH);
descriptionArea.setText("No function selected or binary not analysed yet with RevEng.AI");
headerPanel.add(descriptionArea);

// Visual divider
headerPanel.add(new JSeparator(SwingConstants.HORIZONTAL));

// Predicted name panel (between description and code)
predictedNamePanel = new JPanel(new FlowLayout(FlowLayout.LEFT));
predictedNameLabel = new JLabel("Predicted name: ");
usePredictedNameButton = new JButton("Use Predicted Name");
usePredictedNameButton.addActionListener(e -> applyPredictedName());
predictedNamePanel.add(predictedNameLabel);
predictedNamePanel.add(usePredictedNameButton);
predictedNamePanel.setVisible(false); // Hidden until we have a prediction
headerPanel.add(predictedNamePanel);

component.add(headerPanel, BorderLayout.NORTH);

// Code area
textArea = new RSyntaxTextArea(20, 60);
textArea.setSyntaxEditingStyle(SyntaxConstants.SYNTAX_STYLE_C);
textArea.setEditable(false);
Expand All @@ -146,20 +169,58 @@ private JComponent buildComponent() {
return component;
}

/**
* Apply the predicted function name to the current function
*/
private void applyPredictedName() {
if (function == null) {
return;
}

var cachedStatus = cache.get(function);
if (cachedStatus == null || cachedStatus.getPredictedFunctionName() == null) {
return;
}

String predictedName = cachedStatus.getPredictedFunctionName();
var program = function.getProgram();

int txId = program.startTransaction("Rename function to predicted name");
try {
function.setName(predictedName, ghidra.program.model.symbol.SourceType.USER_DEFINED);
Msg.info(this, "Renamed function to predicted name: " + predictedName);
} catch (Exception ex) {
Msg.showError(this, this.component, "Failed to rename function:", ex.getMessage(), ex);
} finally {
program.endTransaction(txId, true);
}
}

@Override
public JComponent getComponent() {
return component;
}


public void setDisplayedValuesBasedOnStatus(Function function, AIDecompilationStatus status) {

public void setDisplayedValuesBasedOnStatus(Function function, GetAiDecompilationTask status) {
this.function = function;
if (status.status().equals("success")) {
setCode(status.decompilation());
descriptionArea.setText(status.getMarkedUpSummary());
} else if (status.status().equals("error")) {
if (status.getStatus().equals("success")) {
setCode(status.getDecompilation());
descriptionArea.setText("<html>%s</html>".formatted(status.getSummary()));

// Show predicted name if available
String predictedName = status.getPredictedFunctionName();
if (predictedName != null && !predictedName.isEmpty()) {
predictedNameLabel.setText("Predicted name: " + predictedName);
predictedNamePanel.setVisible(true);
} else {
predictedNamePanel.setVisible(false);
}
} else if (status.getStatus().equals("error")) {
setCode("");
descriptionArea.setText("Decompilation failed");
predictedNamePanel.setVisible(false);
}
}

Expand All @@ -172,6 +233,7 @@ private void clear() {
this.function = null;
setCode("");
descriptionArea.setText("");
predictedNamePanel.setVisible(false);
}

public void refresh(GhidraRevengService.FunctionWithID function) {
Expand Down Expand Up @@ -215,28 +277,28 @@ public void locationChanged(ProgramLocation loc) {
}


void newStatusForFunction(Function function, AIDecompilationStatus status) {
void newStatusForFunction(Function function, GetAiDecompilationTask status) {
cache.put(function, status);
if (function == this.function) {
SwingUtilities.invokeLater(() ->
setDisplayedValuesBasedOnStatus(function, status)
);
}
if (status.status().equals("success")) {
if (status.getStatus().equals("success")) {
var logger = tool.getService(ReaiLoggingService.class);
logger.info("AI Decompilation finished for function %s: %s".formatted(function.getName(), status.decompilation()));
logger.info("AI Decompilation finished for function %s: %s".formatted(function.getName(), status.getDecompilation()));
if (!hasPendingDecompilations()) {
taskMonitorComponent.setVisible(false);
}
} else if (status.status().equals("error")) {
} else if (status.getStatus().equals("error")) {
if (!hasPendingDecompilations()) {
taskMonitorComponent.setVisible(false);
}
}
}

private boolean hasPendingDecompilations() {
return cache.values().stream().anyMatch(s -> s.status().equals("pending") || s.status().equals("running") || s.status().equals("queued"));
return cache.values().stream().anyMatch(s -> s.getStatus().equals("pending") || s.getStatus().equals("running") || s.getStatus().equals("queued"));
}
class AIDecompTask extends Task {

Expand All @@ -253,7 +315,7 @@ public AIDecompTask(PluginTool tool, GhidraRevengService.FunctionWithID function
public void run(TaskMonitor monitor) throws CancelledException {
var fID = functionWithID.functionID();
// Check if there is an existing process already, because the trigger API will fail with 400 if there is
if (service.getApi().pollAIDecompileStatus(fID).status().equals("uninitialised")) {
if (service.getApi().pollAIDecompileStatus(fID).getStatus().equals("uninitialised")) {
// Trigger the decompilation
service.getApi().triggerAIDecompilationForFunctionID(fID);
}
Expand All @@ -265,17 +327,17 @@ public void run(TaskMonitor monitor) throws CancelledException {
private void waitForDecomp(TypedApiInterface.FunctionID id, TaskMonitor monitor) throws CancelledException {
var logger = tool.getService(ReaiLoggingService.class);
var api = service.getApi();
AIDecompilationStatus lastDecompStatus = null;
GetAiDecompilationTask lastDecompStatus = null;
while (true) {
var newStatus = api.pollAIDecompileStatus(id);
if (lastDecompStatus == null || !Objects.equals(newStatus.status(), lastDecompStatus.status())) {
if (lastDecompStatus == null || !Objects.equals(newStatus.getStatus(), lastDecompStatus.getStatus())) {
lastDecompStatus = newStatus;

newStatusForFunction(functionWithID.function(), newStatus);
}
monitor.setMessage("Waiting for AI Decompilation for %s ... Current status: %s".formatted(functionWithID.function().getName(), lastDecompStatus.status()));
monitor.setMessage("Waiting for AI Decompilation for %s ... Current status: %s".formatted(functionWithID.function().getName(), lastDecompStatus.getStatus()));
monitor.checkCancelled();
switch (newStatus.status()) {
switch (newStatus.getStatus()) {
case "pending":
case "uninitialised":
case "queued":
Expand All @@ -291,10 +353,10 @@ private void waitForDecomp(TypedApiInterface.FunctionID id, TaskMonitor monitor)
monitor.setProgress(monitor.getMaximum());
return;
case "error":
logger.error("Decompilation failed: %s".formatted(newStatus.decompilation()));
logger.error("Decompilation failed: %s".formatted(newStatus.getDecompilation()));
return;
default:
throw new RuntimeException("Unknown status: %s".formatted(newStatus.status()));
throw new RuntimeException("Unknown status: %s".formatted(newStatus.getStatus()));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ public String decompileFunctionViaAI(FunctionWithID functionWithID, TaskMonitor
// Check if there is an existing process already, because the trigger API will fail with 400 if there is
var fID = functionWithID.functionID;
var function = functionWithID.function;
if (api.pollAIDecompileStatus(fID).status().equals("uninitialised")){
if (api.pollAIDecompileStatus(fID).getStatus().equals("uninitialised")){
// Trigger the decompilation
api.triggerAIDecompilationForFunctionID(fID);
}
Expand All @@ -655,7 +655,7 @@ public String decompileFunctionViaAI(FunctionWithID functionWithID, TaskMonitor
var status = api.pollAIDecompileStatus(fID);
window.setDisplayedValuesBasedOnStatus(function, status);

switch (status.status()) {
switch (status.getStatus()) {
case "pending":
case "uninitialised":
case "queued":
Expand All @@ -670,11 +670,11 @@ public String decompileFunctionViaAI(FunctionWithID functionWithID, TaskMonitor
case "success":
monitor.setProgress(monitor.getMaximum());
window.setDisplayedValuesBasedOnStatus(function, status);
return status.decompilation();
return status.getDecompilation();
case "error":
return "Decompilation failed: %s".formatted(status.decompilation());
return "Decompilation failed: %s".formatted(status.getStatus());
default:
throw new RuntimeException("Unknown status: %s".formatted(status.status()));
throw new RuntimeException("Unknown status: %s".formatted(status.getStatus()));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,17 @@ public boolean triggerAIDecompilationForFunctionID(FunctionID functionID) {
}

@Override
public AIDecompilationStatus pollAIDecompileStatus(FunctionID functionID) {

HttpRequest request = requestBuilderForEndpoint("ai-decompilation/" + functionID.value(), "?summarise=true")
.GET()
.build();
return AIDecompilationStatus.fromJSONObject(sendVersion2Request(request).getJsonData());

public GetAiDecompilationTask pollAIDecompileStatus(FunctionID functionID) {
try {
var response = functionsAiDecompilationApi.getAiDecompilationTaskResult(
functionID.value(), // Long functionId
true, // summarise
true // generateInlineComments
);
return response.getData();
} catch (ApiException e) {
throw new RuntimeException("Failed to poll AI decompilation status", e);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ default boolean triggerAIDecompilationForFunctionID(FunctionID functionID) {
throw new UnsupportedOperationException("triggerAIDecompilationForFunctionID not implemented yet");
}

default AIDecompilationStatus pollAIDecompileStatus(FunctionID functionID) {
default GetAiDecompilationTask pollAIDecompileStatus(FunctionID functionID) {
throw new UnsupportedOperationException("pollAIDecompileStatus not implemented yet");
}

Expand Down
Loading
Loading