diff --git a/.gitignore b/.gitignore index fa996640..7d2fd96e 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,5 @@ mock_* *.png # End of https://www.gitignore.io/api/go + +.idea \ No newline at end of file diff --git a/client/src/router/index.js b/client/src/router/index.js index e1671673..36ede2fd 100644 --- a/client/src/router/index.js +++ b/client/src/router/index.js @@ -10,7 +10,7 @@ import Results from '@/pages/Results' import ResponseDetails from '@/pages/ResponseDetails' import NotFound from '@/pages/NotFound' import Blank from '@/pages/Blank' -import { sendTokenRequest, sendCodeRequest } from '../bin/traqAuth' +import { getRequest2Callback, redirect2AuthEndpoint } from '@/util/api.js' Vue.use(Router) @@ -57,10 +57,7 @@ const router = new Router({ { path: '/questionnaires/:id/edit', name: 'QuestionnaireDetailsEdit', - component: QuestionnaireDetails, - meta: { - requiresTraqAuth: true - } + component: QuestionnaireDetails }, { path: '/results/:id', @@ -89,32 +86,12 @@ const router = new Router({ name: 'Callback', component: Blank, beforeEnter: async (to, _, next) => { - const clearSessionStorage = () => { - sessionStorage.removeItem('nextRoute') - sessionStorage.removeItem('previousRoute') - sessionStorage.removeItem(`traq-auth-code-verifier-${state}`) - } - - const code = to.query.code - const state = to.query.state - const codeVerifier = sessionStorage.getItem( - `traq-auth-code-verifier-${state}` - ) - if (!code || !codeVerifier) { - let previousRoute = sessionStorage.getItem('previousRoute') - if (!previousRoute) previousRoute = '/targeted' - clearSessionStorage() - next(previousRoute) - return + await getRequest2Callback(to) + const destination = sessionStorage.getItem('destination') + if (destination) { + next(destination) } - - const res = await sendTokenRequest(code, codeVerifier) - store.commit('traq/setAccessToken', res.data.access_token) - - let nextRoute = sessionStorage.getItem('nextRoute') - if (!nextRoute) nextRoute = '/targeted' - clearSessionStorage() - next(nextRoute) + next() } } ], @@ -129,6 +106,11 @@ const router = new Router({ }) router.beforeEach(async (to, from, next) => { + console.log(to.name) + if (to.name === 'Callback') { + next() + return + } // traQにログイン済みかどうか調べる if (!store.state.me) { await store.dispatch('whoAmI') @@ -136,34 +118,8 @@ router.beforeEach(async (to, from, next) => { if (!store.state.me) { // 未ログインの場合、traQのログインページに飛ばす - const traQLoginURL = 'https://q.trap.jp/login?redirect=' + location.href - location.href = traQLoginURL - } - - if (to.meta.requiresTraqAuth) { - await store.dispatch('traq/ensureToken') - if (!store.state.traq.accessToken) { - const message = - 'アンケートの編集・作成にはtraQアカウントへのアクセスが必要です。OKを押すとtraQに飛びます。' - if (window.confirm(message)) { - sessionStorage.setItem('nextRoute', to.path) // traQでのトークン取得後に飛ばすルート - sessionStorage.setItem('previousRoute', from.path) // traQでのトークン取得失敗時に飛ばすルート - await sendCodeRequest() - - // traQのconsentページに飛ぶ前にnextが表示されることを防ぐ - next(false) - return - } else { - // キャンセルを押された場合は元のルートに戻る - if (from.path !== to.path) { - next(from.path) - } else { - // url直打ちなどでアクセスされた場合 - next('/targeted') - } - return - } - } + sessionStorage.setItem(`destination`, to.fullPath) + await redirect2AuthEndpoint() } next() diff --git a/client/src/util/api.js b/client/src/util/api.js new file mode 100644 index 00000000..49b51db8 --- /dev/null +++ b/client/src/util/api.js @@ -0,0 +1,24 @@ +import axios from 'axios' + +export const traQBaseURL = 'https://q.trap.jp/api/v3' +axios.defaults.baseURL = + process.env.NODE_ENV === 'development' + ? 'http://localhost:8080/api' + : 'https://anke-to.trap.jp/api' + +export async function redirect2AuthEndpoint() { + const data = (await axios.get('/oauth/generate/code')).data + + const authorizationEndpointUrl = new URL(data) + + window.location.assign(authorizationEndpointUrl.toString()) +} + +export async function getRequest2Callback(to) { + return axios.get('/oauth/callback', { + params: { + code: to.query.code, + state: to.query.state + } + }) +} diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index 36bc51a6..24244103 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -10,8 +10,8 @@ RUN go mod download ENV DOCKERIZE_VERSION v0.6.1 RUN apk add --no-cache openssl \ - && wget https://github.com/jwilder/dockerize/releases/download/$DOCKERIZE_VERSION/dockerize-alpine-linux-amd64-$DOCKERIZE_VERSION.tar.gz \ - && tar -C /usr/local/bin -xzvf dockerize-alpine-linux-amd64-$DOCKERIZE_VERSION.tar.gz \ - && rm dockerize-alpine-linux-amd64-$DOCKERIZE_VERSION.tar.gz + && wget https://github.com/jwilder/dockerize/releases/download/$DOCKERIZE_VERSION/dockerize-alpine-linux-amd64-$DOCKERIZE_VERSION.tar.gz \ + && tar -C /usr/local/bin -xzvf dockerize-alpine-linux-amd64-$DOCKERIZE_VERSION.tar.gz \ + && rm dockerize-alpine-linux-amd64-$DOCKERIZE_VERSION.tar.gz ENTRYPOINT dockerize -timeout 10s -wait tcp://mysql:3306 air -c docker/dev/.air.toml \ No newline at end of file diff --git a/docker/dev/docker-compose.yaml b/docker/dev/docker-compose.yaml index 27005b6d..d00235f9 100644 --- a/docker/dev/docker-compose.yaml +++ b/docker/dev/docker-compose.yaml @@ -14,6 +14,9 @@ services: MARIADB_DATABASE: anke-to TZ: Asia/Tokyo GO111MODULE: "on" + CLIENT_ID: + CLIENT_SECRET: + SESSION_SECRET: secret ports: - "1323:1323" volumes: diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index 5630757f..04acfeda 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -22,6 +22,7 @@ tags: - name: user - name: group - name: result + - name: oauth paths: /questionnaires: get: @@ -461,8 +462,45 @@ paths: description: 結果を閲覧する権限がありません。 '500': description: アンケートの回答の詳細情報一覧が取得できませんでした + '/oauth/generate/code': + get: + operationId: getCode + tags: + - oauth + summary: Oauthの詳細を取得 + description: Oauthの詳細を取得 + responses: + '200': + description: 成功 + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthCode' + '500': + description: 失敗 + '/oauth/callback': + parameters: + - $ref: '#/components/parameters/codeInQuery' + get: + tags: + - oauth + summary: OAuthのコールバック + description: OAuthのコールバック + operationId: callback + responses: + '200': + description: 成功 + '302': + description: 失敗時。認証ページへリダイレクト components: parameters: + codeInQuery: + name: code + in: query + required: true + description: OAuth2.0のcode + schema: + type: string answeredInQuery: name: answered in: query @@ -539,6 +577,17 @@ components: schema: type: string schemas: + OAuthCode: + type: object + properties: + code_challenge: + type: string + code_challenge_method: + type: string + client_id: + type: string + response_type: + type: string AnsweredType: type: string description: アンケート検索時に回答済みかの状態での絞り込み @@ -618,7 +667,7 @@ components: - administrators NewQuestionnaireResponse: allOf: - - $ref: '#/components/schemas/QuestionnaireUser' + - $ref: '#/components/schemas/QuestionnaireUser' Questionnaire: type: object properties: @@ -687,20 +736,20 @@ components: - respondents QuestionnaireMyTargeted: allOf: - - $ref: '#/components/schemas/Questionnaire' - - type: object - properties: - responded_at: - type: string - format: date-time - has_response: - type: boolean - description: 回答済みあるいは下書きが存在する - required: - - responded_at - - has_response + - $ref: '#/components/schemas/Questionnaire' + - type: object + properties: + responded_at: + type: string + format: date-time + has_response: + type: boolean + description: 回答済みあるいは下書きが存在する + required: + - responded_at + - has_response QuestionnaireMyAdministrates: - allOf: + allOf: - $ref: '#/components/schemas/QuestionnaireUser' - type: object properties: @@ -716,16 +765,16 @@ components: - respondents QuestionnaireUser: allOf: - - $ref: '#/components/schemas/Questionnaire' - - type: object - properties: - targets: - $ref: '#/components/schemas/Users' - administrators: - $ref: '#/components/schemas/Users' - required: - - targets - - administrators + - $ref: '#/components/schemas/Questionnaire' + - type: object + properties: + targets: + $ref: '#/components/schemas/Users' + administrators: + $ref: '#/components/schemas/Users' + required: + - targets + - administrators QuestionType: type: string example: Text @@ -800,38 +849,38 @@ components: - scale_max NewQuestion: allOf: - - $ref: '#/components/schemas/QuestionBase' - - type: object - properties: - questionnaireID: - type: integer - example: 1 - required: - - questionnaireID + - $ref: '#/components/schemas/QuestionBase' + - type: object + properties: + questionnaireID: + type: integer + example: 1 + required: + - questionnaireID Question: allOf: - - $ref: '#/components/schemas/NewQuestion' - - type: object - properties: - questionID: - type: integer - example: 1 - required: - - questionID + - $ref: '#/components/schemas/NewQuestion' + - type: object + properties: + questionID: + type: integer + example: 1 + required: + - questionID QuestionDetails: allOf: - - $ref: '#/components/schemas/QuestionBase' - - type: object - properties: - questionID: - type: integer - example: 1 - created_at: - type: string - format: date-time - required: - - questionID - - created_at + - $ref: '#/components/schemas/QuestionBase' + - type: object + properties: + questionID: + type: integer + example: 1 + created_at: + type: string + format: date-time + required: + - questionID + - created_at NewResponse: type: object properties: @@ -936,14 +985,14 @@ components: - question_type ResponseResult: allOf: - - $ref: '#/components/schemas/Response' - - type: object - properties: - traqID: - type: string - example: lolico - required: - - traqID + - $ref: '#/components/schemas/Response' + - type: object + properties: + traqID: + type: string + example: lolico + required: + - traqID required: - submitted_at Users: diff --git a/go.mod b/go.mod index fe6a6f19..4308e3a0 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 // indirect golang.org/x/mod v0.5.0 // indirect golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect - golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f + golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect golang.org/x/text v0.3.7 // indirect @@ -61,7 +61,15 @@ require ( require ( github.com/go-sql-driver/mysql v1.6.0 // indirect + github.com/gorilla/sessions v1.2.1 github.com/prometheus/client_golang v1.11.0 + github.com/thanhpk/randstr v1.0.4 gopkg.in/guregu/null.v4 v4.0.0 gorm.io/plugin/prometheus v0.0.0-20210820101226-2a49866f83ee ) + +require ( + github.com/gorilla/context v1.1.1 // indirect + github.com/gorilla/securecookie v1.1.1 // indirect + github.com/srinathgs/mysqlstore v0.0.0-20200417050510-9cbb9420fc4c // indirect +) diff --git a/go.sum b/go.sum index 74fb483f..9ee5d075 100644 --- a/go.sum +++ b/go.sum @@ -194,10 +194,13 @@ github.com/google/wire v0.5.0/go.mod h1:ngWDr9Qvq3yZA10YrxfyGELY/AFWGVpy9c1LTRi1 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= @@ -389,6 +392,8 @@ github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4k github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/srinathgs/mysqlstore v0.0.0-20200417050510-9cbb9420fc4c h1:HT6QRF79dL2Ed6HCrX9RufkxFGo7+NPkgYF1Uzvv/js= +github.com/srinathgs/mysqlstore v0.0.0-20200417050510-9cbb9420fc4c/go.mod h1:kt46Hd+lF0rtpeRgOvYSWYJItOAd73EKkIBZFbX7TXs= github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= @@ -400,6 +405,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/thanhpk/randstr v1.0.4 h1:IN78qu/bR+My+gHCvMEXhR/i5oriVHcTB/BJJIRTsNo= +github.com/thanhpk/randstr v1.0.4/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.4.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= @@ -523,8 +530,8 @@ golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4Iltr golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f h1:Qmd2pbz05z7z6lm0DrgQVVPuBm92jqujBKMHMOlOQEw= -golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 h1:RerP+noqYHUQ8CMRcPlC2nvTa4dcBIjegkuWdcUDuqg= +golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/model/session.go b/model/session.go new file mode 100644 index 00000000..fa1c760d --- /dev/null +++ b/model/session.go @@ -0,0 +1,9 @@ +//go:generate mockgen -source=$GOFILE -destination=mock_$GOPACKAGE/mock_$GOFILE + +package model + +import "github.com/srinathgs/mysqlstore" + +type ISession interface { + Get() (*mysqlstore.MySQLStore,error) +} diff --git a/model/session_impl.go b/model/session_impl.go new file mode 100644 index 00000000..593c8692 --- /dev/null +++ b/model/session_impl.go @@ -0,0 +1,28 @@ +package model + +import ( + "fmt" + "github.com/srinathgs/mysqlstore" + "os" +) + +type Session struct { +} + +func (s *Session) Get() (*mysqlstore.MySQLStore, error) { + _db, err := db.DB() + if err != nil { + return nil, fmt.Errorf("failed to get sql.DB :%w", err) + } + + store, err := mysqlstore.NewMySQLStoreFromConnection(_db, "sessions", "/", 60*60*24*14, []byte(os.Getenv("SESSION_SECRET"))) + if err != nil { + return nil, fmt.Errorf("failed to create session store:%w", err) + } + + return store, nil +} + +func NewSession() *Session { + return &Session{} +} diff --git a/router.go b/router.go index 389a4d19..66434fd5 100644 --- a/router.go +++ b/router.go @@ -4,6 +4,7 @@ import ( "github.com/labstack/echo-contrib/prometheus" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "log" ) // SetRouting ルーティングの設定 @@ -16,7 +17,10 @@ func SetRouting(port string) { p := prometheus.NewPrometheus("echo", nil) p.Use(e) - api := InjectAPIServer() + api, err := InjectAPIServer() + if err != nil { + log.Panicln(err) + } // Static Files e.Static("/", "client/dist") @@ -29,6 +33,18 @@ func SetRouting(port string) { e.File("/favicon.ico", "client/dist/favicon.ico") e.File("*", "client/dist/index.html") + + e.Use(api.SessionMiddleware()) + + oauthAPI := e.Group("/api") + { + apiOauth := oauthAPI.Group("/oauth") + { + apiOauth.GET("/callback", api.Callback) + apiOauth.GET("/generate/code", api.GetCode) + } + } + echoAPI := e.Group("/api", api.SetValidatorMiddleware, api.SetUserIDMiddleware, api.TraPMemberAuthenticate) { apiQuestionnnaires := echoAPI.Group("/questionnaires") @@ -58,10 +74,7 @@ func SetRouting(port string) { apiUsers := echoAPI.Group("/users") { - /* - TODO - apiUsers.GET("") - */ + apiUsersMe := apiUsers.Group("/me") { apiUsersMe.GET("", api.GetUsersMe) diff --git a/router/api.go b/router/api.go index 3981794b..ccd1e9f3 100644 --- a/router/api.go +++ b/router/api.go @@ -8,10 +8,11 @@ type API struct { *Response *Result *User + *Oauth } // NewAPI APIのコンストラクタ -func NewAPI(middleware *Middleware, questionnaire *Questionnaire, question *Question, response *Response, result *Result, user *User) *API { +func NewAPI(middleware *Middleware, questionnaire *Questionnaire, question *Question, response *Response, result *Result, user *User, oauth *Oauth) *API { return &API{ Middleware: middleware, Questionnaire: questionnaire, @@ -19,5 +20,6 @@ func NewAPI(middleware *Middleware, questionnaire *Questionnaire, question *Ques Response: response, Result: result, User: user, + Oauth: oauth, } } diff --git a/router/middleware.go b/router/middleware.go index b638e536..da9b899b 100644 --- a/router/middleware.go +++ b/router/middleware.go @@ -3,6 +3,8 @@ package router import ( "errors" "fmt" + "github.com/traPtitech/anke-to/router/session" + "github.com/traPtitech/anke-to/traq" "net/http" "strconv" @@ -18,18 +20,24 @@ type Middleware struct { model.IRespondent model.IQuestion model.IQuestionnaire + session.IStore + traq.IUser } // NewMiddleware Middlewareのコンストラクタ -func NewMiddleware(administrator model.IAdministrator, respondent model.IRespondent, question model.IQuestion, questionnaire model.IQuestionnaire) *Middleware { +func NewMiddleware(IAdministrator model.IAdministrator, IRespondent model.IRespondent, IQuestion model.IQuestion, IQuestionnaire model.IQuestionnaire, IStore session.IStore, IUser traq.IUser) *Middleware { return &Middleware{ - IAdministrator: administrator, - IRespondent: respondent, - IQuestion: question, - IQuestionnaire: questionnaire, + IAdministrator: IAdministrator, + IRespondent: IRespondent, + IQuestion: IQuestion, + IQuestionnaire: IQuestionnaire, + IStore: IStore, + IUser: IUser, } } + + const ( validatorKey = "validator" userIDKey = "userID" @@ -38,7 +46,7 @@ const ( questionIDKey = "questionID" ) -func (*Middleware) SetValidatorMiddleware(next echo.HandlerFunc) echo.HandlerFunc { +func (m *Middleware) SetValidatorMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { validate := validator.New() c.Set(validatorKey, validate) @@ -51,14 +59,39 @@ func (*Middleware) SetValidatorMiddleware(next echo.HandlerFunc) echo.HandlerFun 暫定的にハードコーディングで対応*/ var adminUserIDs = []string{"temma", "sappi_red", "ryoha", "mazrean", "xxarupakaxx", "asari"} -// SetUserIDMiddleware X-Showcase-UserからユーザーIDを取得しセットする -func (*Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { +func (m *Middleware) SessionMiddleware() echo.MiddlewareFunc { + return m.IStore.GetMiddleware() +} + +// SetUserIDMiddleware SessionからUserIDを取得 +func (m *Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - userID := c.Request().Header.Get("X-Showcase-User") - if userID == "" { - userID = "mds_boy" + sess, err := m.IStore.GetSession(c) + if errors.Is(err, session.ErrNoSession) { + return echo.NewHTTPError(http.StatusUnauthorized, "no session") + } + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err) + } + + userID, err := sess.GetUserID() + if err != nil && !errors.Is(err, session.ErrNoValue) { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID :%w", err)) } + if errors.Is(err, session.ErrNoValue) { + token, err := sess.GetToken() + if errors.Is(err, session.ErrNoValue) { + return echo.NewHTTPError(http.StatusUnauthorized, "no token") + } + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get token :%w", err)) + } + userID, err = m.IUser.GetMyID(token) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get UserID:%w", err)) + } + } c.Set(userIDKey, userID) return next(c) @@ -66,7 +99,7 @@ func (*Middleware) SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc { } // TraPMemberAuthenticate traP部員かの認証 -func (*Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { +func (m *Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { userID, err := getUserID(c) if err != nil { @@ -85,7 +118,7 @@ func (*Middleware) TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFun } // TrapRateLimitMiddlewareFunc traP IDベースのリクエスト制限 -func (*Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc { +func (m *Middleware) TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc { config := middleware.RateLimiterConfig{ Store: middleware.NewRateLimiterMemoryStore(5), IdentifierExtractor: func(c echo.Context) (string, error) { diff --git a/router/middleware_test.go b/router/middleware_test.go index 84503255..1d2febd1 100644 --- a/router/middleware_test.go +++ b/router/middleware_test.go @@ -3,6 +3,8 @@ package router import ( "errors" "fmt" + "github.com/traPtitech/anke-to/router/session/mock_session" + "github.com/traPtitech/anke-to/traq/mock_traq" "net/http" "net/http/httptest" "strconv" @@ -39,8 +41,9 @@ func TestSetUserIDMiddleware(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestionnaire := mock_model.NewMockIQuestionnaire(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - - middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire) + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) + middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire, mockStore, mockUser) type args struct { userID string @@ -113,8 +116,9 @@ func TestTraPMemberAuthenticate(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestionnaire := mock_model.NewMockIQuestionnaire(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - - middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire) + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) + middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire, mockStore, mockUser) type args struct { userID string @@ -181,8 +185,9 @@ func TestResponseReadAuthenticate(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestionnaire := mock_model.NewMockIQuestionnaire(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - - middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire) + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) + middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire, mockStore, mockUser) type args struct { userID string @@ -407,8 +412,9 @@ func TestResultAuthenticate(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestionnaire := mock_model.NewMockIQuestionnaire(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - - middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire) + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) + middleware := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire, mockStore, mockUser) type args struct { haveReadPrivilege bool diff --git a/router/oauth.go b/router/oauth.go new file mode 100644 index 00000000..4f30ecc3 --- /dev/null +++ b/router/oauth.go @@ -0,0 +1,113 @@ +package router + +import ( + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net/http" + "os" + + "github.com/labstack/echo/v4" + "github.com/traPtitech/anke-to/router/session" + "golang.org/x/oauth2" +) + +var ( + clientID = os.Getenv("CLIENT_ID") + clientSecret = os.Getenv("CLIENT_SECRET") + baseURL = "https://q.trap.jp/api/v3" +) + +type Oauth struct { + config *oauth2.Config + sessStore session.IStore +} + +func NewOauth(sessStore session.IStore) *Oauth { + return &Oauth{ + config: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("%s/%s", baseURL, "oauth2/authorize"), + TokenURL: fmt.Sprintf("%s/%s", baseURL, "oauth2/token"), + }, + Scopes: []string{"read"}, + }, + sessStore: sessStore, + } +} + +// GetCode GET /oauth/generate/code +func (o *Oauth) GetCode(c echo.Context) error { + sess, err := o.sessStore.GetSession(c) + if err != nil && !errors.Is(err, session.ErrNoSession) { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get session:%w", err)) + } + + verifier := RandomString(90) + sess.SetVerifier(verifier) + + state := RandomString(32) + sess.SetState(state) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + challengeOption := oauth2.SetAuthURLParam("code_challenge", challenge) + methodOption := oauth2.SetAuthURLParam("code_challenge_method", "S256") + + authURL := o.config.AuthCodeURL(state, challengeOption, methodOption) + + if err = sess.Save(); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to save session:%w", err)) + } + + return c.String(http.StatusOK, authURL) +} + +// Callback GET /oauth/callback +func (o *Oauth) Callback(c echo.Context) error { + sess, err := o.sessStore.GetSession(c) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get session :%w", err)) + } + + state := c.QueryParam("state") + sessState, err := sess.GetState() + if err != nil { + if errors.Is(err, session.ErrNoValue) { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Errorf("failed to get state for no value : %w", err)) + } + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get state :%w", err)) + } + + if state != sessState { + return echo.NewHTTPError(http.StatusUnauthorized, "failed to match state") + } + + code := c.QueryParam("code") + verifier, err := sess.GetVerifier() + if err != nil { + if errors.Is(err, session.ErrNoValue) { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Errorf("failed to get verifier for no value:%w", err)) + } + + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get verifier :%w", err)) + } + + codeChallengeOption := oauth2.SetAuthURLParam("code_verifier", verifier) + token, err := o.config.Exchange(c.Request().Context(), code, codeChallengeOption) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to exchange token: %w", err)) + } + + sess.SetToken(token) + + if err = sess.Save(); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to save :%w ", err)) + } + + return c.NoContent(http.StatusOK) +} diff --git a/router/responses_test.go b/router/responses_test.go index 1ea9845b..a8cd4bda 100644 --- a/router/responses_test.go +++ b/router/responses_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "errors" "fmt" + "github.com/traPtitech/anke-to/router/session/mock_session" + "github.com/traPtitech/anke-to/traq/mock_traq" "github.com/go-playground/validator/v10" @@ -524,6 +526,8 @@ func TestPostResponse(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) r := NewResponse( mockQuestionnaire, @@ -537,6 +541,9 @@ func TestPostResponse(t *testing.T) { mockRespondent, mockQuestion, mockQuestionnaire, + mockStore, + mockUser, + ) // Questionnaire // GetQuestionnaireLimit @@ -642,6 +649,10 @@ func TestPostResponse(t *testing.T) { InsertResponses(gomock.Any(), responseIDFailure, gomock.Any()). Return(errMock).AnyTimes() + + mockStore.EXPECT(). + GetSession(gomock.Any()). + Return(nil,nil).AnyTimes() // responseID, err := mockRespondent. // InsertRespondent(string(userOne), 1, null.NewTime(nowTime, true)) // assertion.Equal(1, responseID) @@ -1063,7 +1074,8 @@ func TestGetResponse(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) r := NewResponse( mockQuestionnaire, mockValidation, @@ -1076,6 +1088,8 @@ func TestGetResponse(t *testing.T) { mockRespondent, mockQuestion, mockQuestionnaire, + mockStore, + mockUser, ) // Respondent @@ -1093,6 +1107,10 @@ func TestGetResponse(t *testing.T) { GetRespondentDetail(gomock.Any(), responseIDNotFound). Return(model.RespondentDetail{}, model.ErrRecordNotFound).AnyTimes() + mockStore.EXPECT(). + GetSession(gomock.Any()). + Return(nil,nil).AnyTimes() + type request struct { user users responseID int @@ -1231,6 +1249,8 @@ func TestEditResponse(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) r := NewResponse( mockQuestionnaire, @@ -1244,7 +1264,10 @@ func TestEditResponse(t *testing.T) { mockRespondent, mockQuestion, mockQuestionnaire, + mockStore, + mockUser, ) + // Questionnaire // GetQuestionnaireLimit // success @@ -1363,6 +1386,11 @@ func TestEditResponse(t *testing.T) { DeleteResponse(gomock.Any(), responseIDFailure). Return(model.ErrNoRecordDeleted).AnyTimes() + + mockStore.EXPECT(). + GetSession(gomock.Any()). + Return(nil,nil).AnyTimes() + // responseID, err := mockRespondent. // InsertRespondent(string(userOne), 1, null.NewTime(nowTime, true)) // assertion.Equal(1, responseID) diff --git a/router/session/error.go b/router/session/error.go new file mode 100644 index 00000000..a8eeb9d4 --- /dev/null +++ b/router/session/error.go @@ -0,0 +1,10 @@ +package session + +import "errors" + +var ( + ErrNoSession = errors.New("no session") + ErrNoValue = errors.New("no value") +) + + diff --git a/router/session/session.go b/router/session/session.go new file mode 100644 index 00000000..b2bba2bb --- /dev/null +++ b/router/session/session.go @@ -0,0 +1,10 @@ +//go:generate mockgen -source=$GOFILE -destination=mock_$GOPACKAGE/mock_$GOFILE + +package session + +import "github.com/labstack/echo/v4" + +type IStore interface { + GetMiddleware() echo.MiddlewareFunc + GetSession(c echo.Context) (*Session,error) +} diff --git a/router/session/session_impl.go b/router/session/session_impl.go new file mode 100644 index 00000000..e07265f5 --- /dev/null +++ b/router/session/session_impl.go @@ -0,0 +1,129 @@ +package session + +import ( + "fmt" + "time" + + "github.com/gorilla/sessions" + "github.com/labstack/echo-contrib/session" + "github.com/labstack/echo/v4" + "github.com/srinathgs/mysqlstore" + "github.com/traPtitech/anke-to/model" + "golang.org/x/oauth2" +) + +type Store struct { + store *mysqlstore.MySQLStore +} + +func (s *Store) GetMiddleware() echo.MiddlewareFunc { + return session.Middleware(s.store) +} + +func (s *Store) GetSession(c echo.Context) (*Session, error) { + sess, err := session.Get("sessions", c) + if err != nil { + return nil, fmt.Errorf("failed to get session:%w", err) + } + + return &Session{ + c: c, + sess: sess, + }, nil +} + +func NewStore(sess *model.Session) (*Store, error) { + store, err := sess.Get() + if err != nil { + return nil, fmt.Errorf("failed to get session: %w", err) + } + return &Store{store: store}, nil +} + +type Session struct { + c echo.Context + sess *sessions.Session +} + +func (s *Session) SetUserID(userID string) { + s.sess.Values["userID"] = userID +} + +func (s *Session) GetUserID() (string, error) { + userID, ok := s.sess.Values["userID"].(string) + if !ok || userID == "" { + return "", ErrNoValue + } + + return userID, nil +} + +func (s *Session) SetVerifier(verifier string) { + s.sess.Values["verifier"] = verifier +} + +func (s *Session) GetVerifier() (string, error) { + verifier, ok := s.sess.Values["verifier"].(string) + if !ok || verifier == "" { + return "", ErrNoValue + } + + return verifier, nil +} + +func (s *Session) SetToken(token *oauth2.Token) { + s.sess.Values["access_token"] = token.AccessToken + s.sess.Values["token_type"] = token.TokenType + s.sess.Values["refresh_token"] = token.RefreshToken + s.sess.Values["expiry"] = token.Expiry +} + +func (s *Session) GetToken() (*oauth2.Token, error) { + iAccessToken, ok := s.sess.Values["access_token"] + if !ok || iAccessToken == nil { + return nil, ErrNoValue + } + + iTokenType, ok := s.sess.Values["token_type"] + if !ok || iTokenType == nil { + return nil, ErrNoValue + } + + iRefreshToken, ok := s.sess.Values["refresh_token"] + if !ok || iRefreshToken == nil { + return nil, ErrNoValue + } + + iExpiry, ok := s.sess.Values["expiry"] + if !ok || iExpiry == nil { + return nil, ErrNoValue + } + + return &oauth2.Token{ + AccessToken: iAccessToken.(string), + TokenType: iTokenType.(string), + RefreshToken: iRefreshToken.(string), + Expiry: iExpiry.(time.Time), + }, nil +} + +func (s *Session) SetState(state string) { + s.sess.Values["state"] = state +} + +func (s *Session) GetState() (string, error) { + state, ok := s.sess.Values["state"].(string) + if !ok || state == "" { + return "", ErrNoValue + } + + return state, nil +} + +func (s *Session) Save() error { + if err := s.sess.Save(s.c.Request(),s.c.Response());err!=nil { + return err + } + + return nil +} diff --git a/router/users_test.go b/router/users_test.go index c4d77032..dc339591 100644 --- a/router/users_test.go +++ b/router/users_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "github.com/go-playground/validator/v10" + "github.com/traPtitech/anke-to/router/session/mock_session" + "github.com/traPtitech/anke-to/traq/mock_traq" "net/http" "testing" "time" @@ -162,7 +164,8 @@ func TestGetUsersMe(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) u := NewUser( mockRespondent, mockQuestionnaire, @@ -174,6 +177,8 @@ func TestGetUsersMe(t *testing.T) { mockRespondent, mockQuestion, mockQuestionnaire, + mockStore, + mockUser, ) type request struct { @@ -305,19 +310,16 @@ func TestGetMyResponses(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) u := NewUser( mockRespondent, mockQuestionnaire, mockTarget, mockAdministrator, ) - m := NewMiddleware( - mockAdministrator, - mockRespondent, - mockQuestion, - mockQuestionnaire, - ) + + m := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire, mockStore, mockUser) // Respondent // GetRespondentInfos @@ -462,19 +464,15 @@ func TestGetMyResponsesByID(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) + m := NewMiddleware(mockAdministrator, mockRespondent, mockQuestion, mockQuestionnaire, mockStore, mockUser) u := NewUser( mockRespondent, mockQuestionnaire, mockTarget, mockAdministrator, ) - m := NewMiddleware( - mockAdministrator, - mockRespondent, - mockQuestion, - mockQuestionnaire, - ) // Respondent // GetRespondentInfos @@ -634,7 +632,8 @@ func TestGetTargetedQuestionnaire(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) u := NewUser( mockRespondent, mockQuestionnaire, @@ -646,6 +645,8 @@ func TestGetTargetedQuestionnaire(t *testing.T) { mockRespondent, mockQuestion, mockQuestionnaire, + mockStore, + mockUser, ) // Questionnaire @@ -785,7 +786,8 @@ func TestGetTargettedQuestionnairesBytraQID(t *testing.T) { mockAdministrator := mock_model.NewMockIAdministrator(ctrl) mockQuestion := mock_model.NewMockIQuestion(ctrl) - + mockStore := mock_session.NewMockIStore(ctrl) + mockUser := mock_traq.NewMockIUser(ctrl) u := NewUser( mockRespondent, mockQuestionnaire, @@ -797,6 +799,8 @@ func TestGetTargettedQuestionnairesBytraQID(t *testing.T) { mockRespondent, mockQuestion, mockQuestionnaire, + mockStore, + mockUser, ) // Questionnaire diff --git a/router/utils.go b/router/utils.go new file mode 100644 index 00000000..7bc5d868 --- /dev/null +++ b/router/utils.go @@ -0,0 +1,13 @@ +package router + +import "math/rand" + +func RandomString(n int) string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, n) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} diff --git a/traq/user.go b/traq/user.go new file mode 100644 index 00000000..a2421e6b --- /dev/null +++ b/traq/user.go @@ -0,0 +1,9 @@ +//go:generate mockgen -source=$GOFILE -destination=mock_$GOPACKAGE/mock_$GOFILE + +package traq + +import "golang.org/x/oauth2" + +type IUser interface { + GetMyID(token *oauth2.Token) (string, error) +} diff --git a/traq/user_impl.go b/traq/user_impl.go new file mode 100644 index 00000000..1feacbbd --- /dev/null +++ b/traq/user_impl.go @@ -0,0 +1,47 @@ +package traq + +import ( + "encoding/json" + "fmt" + "golang.org/x/oauth2" + "net/http" +) + +type User struct { +} + +type UserRes struct { + ID string `json:"id"` + Name string `json:"name"` +} + +func (u *User) GetMyID(token *oauth2.Token) (string, error) { + path := "https://q.trap.jp/api/v3/users/me" + + req, err := http.NewRequest(http.MethodGet, path, nil) + if err != nil { + return "", fmt.Errorf("failed to create new req :%w", err) + } + + token.SetAuthHeader(req) + httpClient := http.DefaultClient + res, err := httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to get http res :%w", err) + } + + if res.StatusCode != 200 { + return "", fmt.Errorf("failed to get http res :(status :%d): %w", res.StatusCode, res.Status) + } + + user := &UserRes{} + if err = json.NewDecoder(res.Body).Decode(user); err != nil { + return "", fmt.Errorf("failed to decode res:%w", err) + } + + return user.Name, nil +} + +func NewUser() *User { + return &User{} +} diff --git a/wire.go b/wire.go index c32e951f..09240bbd 100644 --- a/wire.go +++ b/wire.go @@ -7,6 +7,7 @@ import ( "github.com/google/wire" "github.com/traPtitech/anke-to/model" "github.com/traPtitech/anke-to/router" + "github.com/traPtitech/anke-to/router/session" "github.com/traPtitech/anke-to/traq" ) @@ -21,11 +22,13 @@ var ( targetBind = wire.Bind(new(model.ITarget), new(*model.Target)) validationBind = wire.Bind(new(model.IValidation), new(*model.Validation)) transactionBind = wire.Bind(new(model.ITransaction), new(*model.Transaction)) + storeBind = wire.Bind(new(session.IStore), new(*session.Store)) webhookBind = wire.Bind(new(traq.IWebhook), new(*traq.Webhook)) + userBind = wire.Bind(new(traq.IUser), new(*traq.User)) ) -func InjectAPIServer() *router.API { +func InjectAPIServer() (*router.API,error) { wire.Build( router.NewAPI, router.NewMiddleware, @@ -34,6 +37,7 @@ func InjectAPIServer() *router.API { router.NewResponse, router.NewResult, router.NewUser, + router.NewOauth, model.NewAdministrator, model.NewOption, model.NewQuestionnaire, @@ -44,7 +48,10 @@ func InjectAPIServer() *router.API { model.NewTarget, model.NewValidation, model.NewTransaction, + model.NewSession, + session.NewStore, traq.NewWebhook, + traq.NewUser, administratorBind, optionBind, questionnaireBind, @@ -56,7 +63,8 @@ func InjectAPIServer() *router.API { validationBind, transactionBind, webhookBind, + storeBind, + userBind, ) - - return nil + return nil,nil } diff --git a/wire_gen.go b/wire_gen.go index bede2c48..9d8af8bb 100644 --- a/wire_gen.go +++ b/wire_gen.go @@ -10,6 +10,7 @@ import ( "github.com/google/wire" "github.com/traPtitech/anke-to/model" "github.com/traPtitech/anke-to/router" + "github.com/traPtitech/anke-to/router/session" "github.com/traPtitech/anke-to/traq" ) @@ -19,12 +20,18 @@ import ( // Injectors from wire.go: -func InjectAPIServer() *router.API { +func InjectAPIServer() (*router.API, error) { administrator := model.NewAdministrator() respondent := model.NewRespondent() question := model.NewQuestion() questionnaire := model.NewQuestionnaire() - middleware := router.NewMiddleware(administrator, respondent, question, questionnaire) + modelSession := model.NewSession() + store, err := session.NewStore(modelSession) + if err != nil { + return nil, err + } + user := traq.NewUser() + middleware := router.NewMiddleware(administrator, respondent, question, questionnaire, store, user) target := model.NewTarget() option := model.NewOption() scaleLabel := model.NewScaleLabel() @@ -36,9 +43,10 @@ func InjectAPIServer() *router.API { response := model.NewResponse() routerResponse := router.NewResponse(questionnaire, validation, scaleLabel, respondent, response) result := router.NewResult(respondent, questionnaire, administrator) - user := router.NewUser(respondent, questionnaire, target, administrator) - api := router.NewAPI(middleware, routerQuestionnaire, routerQuestion, routerResponse, result, user) - return api + routerUser := router.NewUser(respondent, questionnaire, target, administrator) + oauth := router.NewOauth(store) + api := router.NewAPI(middleware, routerQuestionnaire, routerQuestion, routerResponse, result, routerUser, oauth) + return api, nil } // wire.go: @@ -54,6 +62,8 @@ var ( targetBind = wire.Bind(new(model.ITarget), new(*model.Target)) validationBind = wire.Bind(new(model.IValidation), new(*model.Validation)) transactionBind = wire.Bind(new(model.ITransaction), new(*model.Transaction)) + storeBind = wire.Bind(new(session.IStore), new(*session.Store)) webhookBind = wire.Bind(new(traq.IWebhook), new(*traq.Webhook)) + userBind = wire.Bind(new(traq.IUser), new(*traq.User)) )