Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

269 changes: 171 additions & 98 deletions cmd/diode/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package main
// ./diode join -config 0xB7A5bd0345EF1Cc5E66bf61BdeC17D2461fBd968

import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
Expand All @@ -18,7 +18,6 @@ import (
"io"
mrand "math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
Expand All @@ -35,6 +34,10 @@ import (
"github.com/diodechain/diode_client/config"
"github.com/diodechain/diode_client/rpc"
"github.com/diodechain/diode_client/util"
"github.com/oasisprotocol/oasis-sdk/client-sdk/go/client"
oasisConfig "github.com/oasisprotocol/oasis-sdk/client-sdk/go/config"
"github.com/oasisprotocol/oasis-sdk/client-sdk/go/connection"
"github.com/oasisprotocol/oasis-sdk/client-sdk/go/modules/evm"
"golang.org/x/crypto/curve25519"
)

Expand All @@ -47,12 +50,13 @@ var (
Type: command.DaemonCommand,
SingleConnection: true,
}
dryRun = false
network = "mainnet"
rpcURL = ""
contractAddress = ""
wantWireGuard = false
wgSuffix = ""
dryRun = false
network = "mainnet"
contractAddress = ""
contractAddrBytes []byte
oasisClient *OasisClient
wantWireGuard = false
wgSuffix = ""
)

func init() {
Expand All @@ -63,26 +67,128 @@ func init() {
joinCmd.Flag.StringVar(&wgSuffix, "suffix", "", "custom suffix for WireGuard interface and files (default derived from -network)")
}

// JSON-RPC request structure
type jsonRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params []interface{} `json:"params"`
ID int `json:"id"`
const (
oasisSapphireParaTime = "sapphire"
oasisLocalRPCDefault = "127.0.0.1:4222"
oasisLocalChainCtxDefault = "0000000000000000000000000000000000000000000000000000000000000000"
oasisLocalRPCEnv = "OASIS_LOCAL_GRPC"
oasisLocalChainContextEnv = "OASIS_LOCAL_CHAIN_CONTEXT"
oasisLocalSapphireIDEnv = "OASIS_LOCAL_SAPPHIRE_ID"
oasisSimulateGasLimit uint64 = 2_000_000
)

var (
evmZeroAddress = make([]byte, 20)
evmZeroValue = []byte{0}
evmZeroGasPrice = []byte{0}
)

// OasisClient wraps the Oasis SDK connection for contract interactions.
type OasisClient struct {
conn connection.Connection
runtime connection.RuntimeClient
evm evm.V1
ctx context.Context
networkName string
rpcEndpoint string
paratimeName string
}

// JSON-RPC response structure
type jsonRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID int `json:"id"`
Result string `json:"result"`
Error *jsonRPCError `json:"error,omitempty"`
// NewOasisClient creates a new Oasis client for the specified network.
func NewOasisClient(ctx context.Context, networkName string) (*OasisClient, error) {
netCfg, ptCfg, skipVerify, err := resolveSapphireNetwork(networkName)
if err != nil {
return nil, err
}

var conn connection.Connection
if skipVerify {
conn, err = connection.ConnectNoVerify(ctx, netCfg)
} else {
conn, err = connection.Connect(ctx, netCfg)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to Oasis network: %w", err)
}

runtime := conn.Runtime(ptCfg)
return &OasisClient{
conn: conn,
runtime: runtime,
evm: runtime.Evm,
ctx: ctx,
networkName: networkName,
rpcEndpoint: netCfg.RPC,
paratimeName: oasisSapphireParaTime,
}, nil
}

// JSON-RPC error structure
type jsonRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
func resolveSapphireNetwork(networkName string) (*oasisConfig.Network, *oasisConfig.ParaTime, bool, error) {
normalized := strings.ToLower(strings.TrimSpace(networkName))
switch normalized {
case "mainnet", "testnet":
netCfg := oasisConfig.DefaultNetworks.All[normalized]
if netCfg == nil {
return nil, nil, false, fmt.Errorf("unknown network: %s", networkName)
}
ptCfg := netCfg.ParaTimes.All[oasisSapphireParaTime]
if ptCfg == nil {
return nil, nil, false, fmt.Errorf("sapphire paratime not configured for %s", normalized)
}
if err := ptCfg.Validate(); err != nil {
return nil, nil, false, fmt.Errorf("invalid sapphire paratime: %w", err)
}
return netCfg, ptCfg, false, nil
case "local":
rpcEndpoint := strings.TrimSpace(os.Getenv(oasisLocalRPCEnv))
if rpcEndpoint == "" {
rpcEndpoint = oasisLocalRPCDefault
}
chainContext := strings.TrimSpace(os.Getenv(oasisLocalChainContextEnv))
if chainContext == "" {
chainContext = oasisLocalChainCtxDefault
}
sapphireID := strings.TrimSpace(os.Getenv(oasisLocalSapphireIDEnv))
if sapphireID == "" {
sapphireID = defaultSapphireID()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The defaultSapphireID function falls back to using the 'testnet' configuration to get a default Sapphire ID for the 'local' network. This dependency on testnet configuration for a local setup might be confusing for developers and makes the local environment not self-contained.

Consider making OASIS_LOCAL_SAPPHIRE_ID a required environment variable when network=local. This would make the configuration more explicit and less surprising.

}
if sapphireID == "" {
return nil, nil, false, fmt.Errorf("missing Sapphire runtime id for local network")
}

ptCfg := &oasisConfig.ParaTime{ID: sapphireID}
netCfg := &oasisConfig.Network{
ChainContext: chainContext,
RPC: rpcEndpoint,
ParaTimes: oasisConfig.ParaTimes{
Default: oasisSapphireParaTime,
All: map[string]*oasisConfig.ParaTime{
oasisSapphireParaTime: ptCfg,
},
},
}
if err := ptCfg.Validate(); err != nil {
return nil, nil, false, fmt.Errorf("invalid local Sapphire paratime id: %w", err)
}
if err := netCfg.Validate(); err != nil {
return nil, nil, false, fmt.Errorf("invalid local network config: %w", err)
}
return netCfg, ptCfg, true, nil
default:
return nil, nil, false, fmt.Errorf("invalid network: %s", networkName)
}
}

func defaultSapphireID() string {
netCfg := oasisConfig.DefaultNetworks.All["testnet"]
if netCfg == nil {
return ""
}
ptCfg := netCfg.ParaTimes.All[oasisSapphireParaTime]
if ptCfg == nil {
return ""
}
return ptCfg.ID
}

const jsondata = `
Expand All @@ -102,8 +208,24 @@ const jsondata = `
]
`

// getPropertyValues fetches multiple property values from the smart contract in one JSON-RPC batch
// getPropertyValues fetches multiple property values from the smart contract using Oasis SDK.
func getPropertyValues(deviceAddr util.Address, keys []string) (map[string]string, error) {
if oasisClient == nil {
return nil, fmt.Errorf("oasis client not initialized")
}
return oasisClient.GetPropertyValues(deviceAddr, keys)
}

// ConfidentialEVMSimulateCall performs a confidential read-only EVM call on Sapphire.
func (c *OasisClient) ConfidentialEVMSimulateCall(ctx context.Context, contractAddr []byte, caller []byte, callData []byte) ([]byte, error) {
if len(contractAddr) == 0 {
return nil, fmt.Errorf("missing contract address")
}
return c.evm.SimulateCall(ctx, client.RoundLatest, evmZeroGasPrice, oasisSimulateGasLimit, caller, contractAddr, evmZeroValue, callData)
}

// GetPropertyValues fetches property values using the Oasis SDK.
func (c *OasisClient) GetPropertyValues(deviceAddr util.Address, keys []string) (map[string]string, error) {
abi, err := abi.JSON(strings.NewReader(jsondata))
if err != nil {
return nil, fmt.Errorf("failed to parse ABI: %v", err)
Expand All @@ -118,80 +240,25 @@ func getPropertyValues(deviceAddr util.Address, keys []string) (map[string]strin
return map[string]string{}, nil
}

requests := make([]jsonRPCRequest, 0, len(keys))
idToKey := make(map[int]string, len(keys))

for idx, key := range keys {
results := make(map[string]string, len(keys))
var errs []string
for _, key := range keys {
packedData, err := method.Inputs.Pack(deviceAddr, key)
if err != nil {
return nil, fmt.Errorf("failed to pack inputs for key %s: %v", key, err)
}
callData := append(method.ID, packedData...)
callObject := map[string]interface{}{
"to": contractAddress,
"data": "0x" + hex.EncodeToString(callData),
}
reqID := idx + 1
requests = append(requests, jsonRPCRequest{
JSONRPC: "2.0",
Method: "eth_call",
Params: []interface{}{callObject, "latest"},
ID: reqID,
})
idToKey[reqID] = key
}

requestJSON, err := json.Marshal(requests)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch request: %v", err)
}

resp, err := http.Post(rpcURL, "application/json", bytes.NewBuffer(requestJSON))
if err != nil {
return nil, fmt.Errorf("failed to make HTTP request: %v", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err)
}

var responses []jsonRPCResponse
if err := json.Unmarshal(body, &responses); err != nil {
var single jsonRPCResponse
if err2 := json.Unmarshal(body, &single); err2 != nil {
return nil, fmt.Errorf("failed to unmarshal response as batch (%v) or single (%v)", err, err2)
}
responses = []jsonRPCResponse{single}
}

results := make(map[string]string, len(keys))
var errs []string
for _, resp := range responses {
key, ok := idToKey[resp.ID]
if !ok {
continue
}
if resp.Error != nil {
errs = append(errs, fmt.Sprintf("%s: %s (code: %d)", key, resp.Error.Message, resp.Error.Code))
continue
}
if len(resp.Result) < 2 {
results[key] = ""
continue
}
decoded, err := hex.DecodeString(resp.Result[2:])
result, err := c.ConfidentialEVMSimulateCall(c.ctx, contractAddrBytes, evmZeroAddress, callData)
if err != nil {
errs = append(errs, fmt.Sprintf("%s: failed to decode result: %v", key, err))
errs = append(errs, fmt.Sprintf("%s: %v", key, err))
continue
}
if len(decoded) == 0 {
if len(result) == 0 {
results[key] = ""
continue
}
var value string
if err := method.Outputs.Unpack(&value, decoded); err != nil {
if err := method.Outputs.Unpack(&value, result); err != nil {
errs = append(errs, fmt.Sprintf("%s: failed to unpack result: %v", key, err))
continue
}
Expand Down Expand Up @@ -219,7 +286,13 @@ func logJoinContractFetch(deviceAddr util.Address, keys []string) {
if cfg == nil || cfg.Logger == nil {
return
}
cfg.Logger.Debug("Fetching join contract (device=%s, contract=%s, rpc=%s, keys=%v)", deviceAddr.HexString(), contractAddress, rpcURL, keys)
networkName := "unknown"
rpcEndpoint := "unknown"
if oasisClient != nil {
networkName = oasisClient.networkName
rpcEndpoint = oasisClient.rpcEndpoint
}
cfg.Logger.Debug("Fetching join contract (device=%s, contract=%s, network=%s, rpc=%s, keys=%v)", deviceAddr.HexString(), contractAddress, networkName, rpcEndpoint, keys)
}

func logJoinContractFetchResult(deviceAddr util.Address, keys []string, props map[string]string, err error) {
Expand Down Expand Up @@ -1360,20 +1433,20 @@ func joinHandler() (err error) {
cfg.PrintLabel("Fleet address", cfg.FleetAddr.HexString())
}

// If we have a valid contract address, set RPC URL used for eth_call
// If we have a valid contract address, initialize the Oasis SDK client.
if !contractless {
switch network {
case "mainnet":
rpcURL = "https://sapphire.oasis.io"
case "testnet":
rpcURL = "https://testnet.sapphire.oasis.io"
case "local":
rpcURL = "http://localhost:8545"
default:
return fmt.Errorf("invalid network: %s", network)
var err error
contractAddrBytes, err = hex.DecodeString(strings.TrimPrefix(contractAddress, "0x"))
if err != nil || len(contractAddrBytes) != 20 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The number 20 is used here to check the length of the contract address. This is a magic number. To improve readability and maintainability, consider defining a constant for this value, for example const expectedAddressLength = 20, and using that constant here.

return fmt.Errorf("invalid contract address: %s", contractAddress)
}
oasisClient, err = NewOasisClient(context.Background(), network)
if err != nil {
return err
}
cfg.PrintLabel("Contract Address", contractAddress)
} else {
contractAddrBytes = nil
if wantWireGuard {
cfg.PrintInfo("WireGuard key-only mode (no contract address provided)")
} else {
Expand Down
Loading
Loading