forked from mildinvestor/katago-colab
-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.go
165 lines (141 loc) · 4.18 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package main
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"time"
"golang.org/x/crypto/ssh"
"moul.io/http2curl"
)
// HttpError is the error value produced by the HTTP helpers in this file.
//
// StatusCode is an HTTP-style status, Key is a short identifier (this file
// passes values such as "failed_do_get"), and Msg is the text returned by
// Error. The JSON tags allow the value to be serialized directly.
type HttpError struct {
	StatusCode int    `json:"statusCode"`
	Msg        string `json:"msg"`
	Key        string `json:"key"`
}

// Error implements the error interface; it reports only the Msg field.
func (he *HttpError) Error() string { return he.Msg }
// CreateErrorWithMsg returns an *HttpError (as an error) whose Key and Msg
// may differ: key identifies the failure, msg is what Error() will report.
func CreateErrorWithMsg(status int, key string, msg string) error {
	he := new(HttpError)
	he.StatusCode = status
	he.Key = key
	he.Msg = msg
	return he
}
func CreateError(status int, key string) error {
return &HttpError{StatusCode: status, Msg: key, Key: key}
}
// DoHTTPRequest Sends generic http request
func DoHTTPRequest(method string, url string, headers map[string]string, body []byte) (responseBody string, err error) {
httpClient := &http.Client{}
req, _ := http.NewRequest(method, url, bytes.NewBuffer(body))
req.Close = true
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
command, _ := http2curl.GetCurlCommand(req)
response, err := httpClient.Do(req)
if err != nil {
log.Printf("ERROR error requesting with http: %s, error: %v\n", command, err)
err = CreateError(500, "failed_do_get")
return
}
bodyBytes, err := ioutil.ReadAll(response.Body)
response.Body.Close()
if err != nil {
log.Printf("ERROR error requesting with http: %s, error: %v\n", command, err)
err = CreateError(500, "failed_read_body")
return
}
responseBody = string(bodyBytes)
if response.StatusCode < 200 || response.StatusCode >= 300 {
log.Printf("ERROR error requesting with http: %s, status code: %v, response: %s\n", command, response.StatusCode, responseBody)
err = CreateError(500, "invalid_status")
return
}
return
}
// SSHOptions describes how to reach the remote host over SSH.
//
// It is decoded from a JSON object with "host", "port" and "user" keys; Host
// and Port are joined into the dial address and User is the login name.
type SSHOptions struct {
	Host string `json:"host"`
	Port int    `json:"port"`
	User string `json:"user"`
}
// Fixed paths on the remote host used when launching KataGo.
const (
	// KataGoBin is the path of the KataGo executable.
	KataGoBin string = "/content/katago"
	// KataGoConfigFile is the GTP config used when no alternate config is requested.
	KataGoConfigFile string = "/content/katago-colab/config/gtp_colab.cfg"
	// KataGoWeightFile is the neural-network weight file passed via -model.
	KataGoWeightFile string = "/content/weight.bin.gz"
	// KataGoChangeConfigScript is the helper script that switches to an
	// alternate config; it is invoked with the config name as its argument.
	KataGoChangeConfigScript string = "/content/katago-colab/scripts/change_config.sh"
)
func main() {
args := os.Args[1:]
if len(args) < 2 {
log.Printf("ERROR usage: colab-katago SSH_INFO_GOOGLE_DRIVE_FILE_ID USER_PASSWORD")
return
}
fileId := args[0]
userpassword := args[1]
var newConfig *string = nil
if len(args) >= 3 {
newConfig = &args[2]
}
log.Printf("INFO using file ID: %s password: %s\n", fileId, userpassword)
sshJSONURL := "https://drive.google.com/uc?id=" + fileId
response, err := DoHTTPRequest("GET", sshJSONURL, nil, nil)
if err != nil {
log.Printf("ERROR error requestting url: %s, err: %+v\n", sshJSONURL, err)
return
}
log.Printf("ssh options\n%s", response)
sshoptions := SSHOptions{}
// parse json
err = json.Unmarshal([]byte(response), &sshoptions)
if err != nil {
log.Printf("ERROR failed parsing json: %s\n", response)
return
}
config := &ssh.ClientConfig{
Timeout: 30 * time.Second,
User: sshoptions.User,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
config.Auth = []ssh.AuthMethod{ssh.Password(userpassword)}
addr := fmt.Sprintf("%s:%d", sshoptions.Host, sshoptions.Port)
sshClient, err := ssh.Dial("tcp", addr, config)
if err != nil {
log.Fatal("failed to create ssh client", err)
return
}
defer sshClient.Close()
configFile := KataGoConfigFile
if newConfig != nil {
// start the sesssion to do it
session, err := sshClient.NewSession()
if err != nil {
log.Fatal("failed to create ssh session", err)
return
}
defer session.Close()
cmd := fmt.Sprintf("%s %s", KataGoChangeConfigScript, *newConfig)
log.Printf("DEBUG running commad:%s\n", cmd)
configFile = fmt.Sprintf("/content/gtp_colab_%s.cfg", *newConfig)
session.Run(cmd)
}
session, err := sshClient.NewSession()
if err != nil {
log.Fatal("failed to create ssh session", err)
return
}
defer session.Close()
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
cmd := fmt.Sprintf("%s gtp -model %s -config %s", KataGoBin, KataGoWeightFile, configFile)
log.Printf("DEBUG running commad:%s\n", cmd)
session.Run(cmd)
}