diff options
Diffstat (limited to 'test_msgnet_tls/main.go')
-rw-r--r-- | test_msgnet_tls/main.go | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/test_msgnet_tls/main.go b/test_msgnet_tls/main.go new file mode 100644 index 0000000..c45cfdf --- /dev/null +++ b/test_msgnet_tls/main.go @@ -0,0 +1,171 @@ +package main + +// #cgo CFLAGS: -I${SRCDIR}/../salticidae/include/ +// #include <stdlib.h> +// #include "salticidae/network.h" +// void onTerm(int sig, void *); +// void onReceiveHello(msg_t *, msgnetwork_conn_t *, void *); +// void onReceiveAck(msg_t *, msgnetwork_conn_t *, void *); +// bool connHandler(msgnetwork_conn_t *, bool, void *); +// void errorHandler(SalticidaeCError *, bool, void *); +import "C" + +import ( + "encoding/binary" + "os" + "fmt" + "unsafe" + "github.com/Determinant/salticidae-go" +) + +const ( + MSG_OPCODE_HELLO salticidae.Opcode = iota + MSG_OPCODE_ACK +) + +func msgHelloSerialize(name string, text string) salticidae.Msg { + serialized := salticidae.NewDataStream() + t := make([]byte, 4) + binary.LittleEndian.PutUint32(t, uint32(len(name))) + serialized.PutData(t) + serialized.PutData([]byte(name)) + serialized.PutData([]byte(text)) + return salticidae.NewMsgMovedFromByteArray( + MSG_OPCODE_HELLO, salticidae.NewByteArrayMovedFromDataStream(serialized)) +} + +func msgHelloUnserialize(msg salticidae.Msg) (name string, text string) { + p := msg.GetPayloadByMove() + t := p.GetDataInPlace(4); length := binary.LittleEndian.Uint32(t.Get()); t.Release() + t = p.GetDataInPlace(int(length)); name = string(t.Get()); t.Release() + t = p.GetDataInPlace(p.Size()); text = string(t.Get()); t.Release() + return +} + +func msgAckSerialize() salticidae.Msg { + return salticidae.NewMsgMovedFromByteArray(MSG_OPCODE_ACK, salticidae.NewByteArray()) +} + +func checkError(err *salticidae.Error) { + if err.GetCode() != 0 { + fmt.Printf("error during a sync call: %s\n", salticidae.StrError(err.GetCode())) + os.Exit(1) + } +} + +type MyNet struct { + net salticidae.MsgNetwork + name string + peerCert salticidae.UInt256 + cname *C.char +} + +var ( + alice, bob MyNet + ec salticidae.EventContext +) + +//export onTerm +func onTerm(_ C.int, _ unsafe.Pointer) { + ec.Stop() +} + +//export onReceiveHello +func onReceiveHello(_msg *C.struct_msg_t, _conn *C.struct_msgnetwork_conn_t, userdata unsafe.Pointer) { + msg := salticidae.MsgFromC(salticidae.CMsg(_msg)) + conn := salticidae.MsgNetworkConnFromC(salticidae.CMsgNetworkConn(_conn)) + net := conn.GetNet() + name, text := msgHelloUnserialize(msg) + myName := C.GoString((*C.char)(userdata)) + fmt.Printf("[%s] %s says %s\n", myName, name, text) + ack := msgAckSerialize() + net.SendMsg(ack, conn) +} + +//export onReceiveAck +func onReceiveAck(_ *C.struct_msg_t, _conn *C.struct_msgnetwork_conn_t, userdata unsafe.Pointer) { + myName := C.GoString((*C.char)(userdata)) + fmt.Printf("[%s] the peer knows\n", myName) +} + +//export connHandler +func connHandler(_conn *C.struct_msgnetwork_conn_t, connected C.bool, userdata unsafe.Pointer) C.bool { + conn := salticidae.MsgNetworkConnFromC(salticidae.CMsgNetworkConn(_conn)) + net := conn.GetNet() + myName := C.GoString((*C.char)(userdata)) + n := alice + if myName == "bob" { n = bob } + res := true + if connected { + certHash := conn.GetPeerCert().GetDer().GetHash() + res = certHash.IsEq(n.peerCert) + if conn.GetMode() == salticidae.CONN_MODE_ACTIVE { + fmt.Printf("[%s] Connected, sending hello.\n", myName) + hello := msgHelloSerialize(myName, "Hello there!") + n.net.SendMsg(hello, conn) + } else { + status := "fail" + if res { status = "ok" } + fmt.Printf("[%s] Accepted, waiting for greetings.\n" + + "The peer certificate footprint is %s (%s)\n", + myName, certHash.GetHex(), status) + } + } else { + fmt.Printf("[%s] Disconnected, retrying.\n", myName) + err := salticidae.NewError() + net.Connect(conn.GetAddr(), false, &err) + } + return C.bool(res) +} + +//export errorHandler +func errorHandler(_err *C.struct_SalticidaeCError, fatal C.bool, _ unsafe.Pointer) { + err := (*salticidae.Error)(unsafe.Pointer(_err)) + s := "recoverable" + if fatal { s = "fatal" } + fmt.Printf("Captured %s error during an async call: %s\n", s, salticidae.StrError(err.GetCode())) +} + +func genMyNet(ec salticidae.EventContext, + name string, peerCert string, + myAddr salticidae.NetAddr, otherAddr salticidae.NetAddr) MyNet { + err := salticidae.NewError() + netconfig := salticidae.NewMsgNetworkConfig() + netconfig.EnableTLS(true) + netconfig.TLSKeyFile(name + ".pem") + netconfig.TLSCertFile(name + ".pem") + net := salticidae.NewMsgNetwork(ec, netconfig, &err); checkError(&err) + _peerCert := salticidae.NewUInt256FromByteArray(salticidae.NewByteArrayFromHex(peerCert)) + cname := C.CString(name) + n := MyNet { net: net, name: name, peerCert: _peerCert, cname: cname } + n.net.RegHandler(MSG_OPCODE_HELLO, salticidae.MsgNetworkMsgCallback(C.onReceiveHello), unsafe.Pointer(cname)) + n.net.RegHandler(MSG_OPCODE_ACK, salticidae.MsgNetworkMsgCallback(C.onReceiveAck), unsafe.Pointer(cname)) + n.net.RegConnHandler(salticidae.MsgNetworkConnCallback(C.connHandler), unsafe.Pointer(cname)) + n.net.RegErrorHandler(salticidae.MsgNetworkErrorCallback(C.errorHandler), unsafe.Pointer(cname)) + + n.net.Start() + n.net.Listen(myAddr, &err); checkError(&err) + n.net.Connect(otherAddr, false, &err); checkError(&err) + return n +} + +func main() { + ec = salticidae.NewEventContext() + + aliceAddr := salticidae.NewAddrFromIPPortString("127.0.0.1:12345") + bobAddr := salticidae.NewAddrFromIPPortString("127.0.0.1:12346") + + alice = genMyNet(ec, "alice", "ed5a9a8c7429dcb235a88244bc69d43d16b35008ce49736b27aaa3042a674043", aliceAddr, bobAddr) + bob = genMyNet(ec, "bob", "ef3bea4e72f4d0e85da7643545312e2ff6dded5e176560bdffb1e53b1cef4896", bobAddr, aliceAddr) + + ev_int := salticidae.NewSigEvent(ec, salticidae.SigEventCallback(C.onTerm), nil) + ev_int.Add(salticidae.SIGINT) + ev_term := salticidae.NewSigEvent(ec, salticidae.SigEventCallback(C.onTerm), nil) + ev_term.Add(salticidae.SIGTERM) + + ec.Dispatch() + alice.net.Stop() + bob.net.Stop() + C.free(unsafe.Pointer(alice.cname)) + C.free(unsafe.Pointer(bob.cname)) +} |