diff --git a/protoc-gen-connect-python/generator/config.go b/protoc-gen-connect-python/generator/config.go index c949ae8..4ac8a33 100644 --- a/protoc-gen-connect-python/generator/config.go +++ b/protoc-gen-connect-python/generator/config.go @@ -34,6 +34,11 @@ type Config struct { // Imports is how to import dependencies in the generated code. Imports Imports + + // Async indicates whether to only generate asynchronous code. If false, + // only synchronous code will be generated. If nil, both synchronous and + // asynchronous code will be generated. + Async *bool } func parseConfig(p string) Config { @@ -64,6 +69,15 @@ func parseConfig(p string) Config { case "relative": cfg.Imports = ImportsRelative } + case "async": + switch value { + case "true": + trueVal := true + cfg.Async = &trueVal + case "false": + falseVal := false + cfg.Async = &falseVal + } } } return cfg diff --git a/protoc-gen-connect-python/generator/generator.go b/protoc-gen-connect-python/generator/generator.go index eaca278..fc94930 100644 --- a/protoc-gen-connect-python/generator/generator.go +++ b/protoc-gen-connect-python/generator/generator.go @@ -55,6 +55,13 @@ func generateConnectFile(fd protoreflect.FileDescriptor, conf Config) (string, s ModuleName: moduleName, Imports: importStatements(fd, conf), } + if conf.Async != nil { + if *conf.Async { + vars.SkipSync = true + } else { + vars.SkipAsync = true + } + } svcs := fd.Services() packageName := string(fd.Package()) @@ -109,7 +116,7 @@ func generateConnectFile(fd protoreflect.FileDescriptor, conf Config) (string, s vars.Services = append(vars.Services, connectSvc) } - var buf = &bytes.Buffer{} + buf := &bytes.Buffer{} err := ConnectTemplate.Execute(buf, vars) if err != nil { return "", "", fmt.Errorf("failed to execute template: %w", err) diff --git a/protoc-gen-connect-python/generator/generator_test.go b/protoc-gen-connect-python/generator/generator_test.go index ff9b2ac..24f2855 100644 --- a/protoc-gen-connect-python/generator/generator_test.go +++ b/protoc-gen-connect-python/generator/generator_test.go @@ -127,10 +127,11 @@ func TestGenerate(t *testing.T) { t.Parallel() tests := []struct { - name string - req *pluginpb.CodeGeneratorRequest - wantStrings []string - wantErr bool + name string + req *pluginpb.CodeGeneratorRequest + wantStrings []string + dontWantStrings []string + wantErr bool }{ { name: "empty request", @@ -195,7 +196,127 @@ func TestGenerate(t *testing.T) { }, }, wantErr: false, - wantStrings: []string{"def try_(self"}, + wantStrings: []string{"class TestServiceASGIApplication", "class TestServiceWSGIApplication"}, + }, + { + name: "async only", + req: &pluginpb.CodeGeneratorRequest{ + FileToGenerate: []string{"test.proto"}, + Parameter: proto.String("async=true"), + ProtoFile: []*descriptorpb.FileDescriptorProto{ + { + Name: proto.String("test.proto"), + Package: proto.String("test"), + Dependency: []string{"other.proto"}, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("TestService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("TestMethod"), + InputType: proto.String(".test.TestRequest"), + OutputType: proto.String(".test.TestResponse"), + }, + { + Name: proto.String("TestMethod2"), + InputType: proto.String(".otherpackage.OtherRequest"), + OutputType: proto.String(".otherpackage.OtherResponse"), + }, + // Reserved keyword + { + Name: proto.String("Try"), + InputType: proto.String(".otherpackage.OtherRequest"), + OutputType: proto.String(".otherpackage.OtherResponse"), + }, + }, + }, + }, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("TestRequest"), + }, + { + Name: proto.String("TestResponse"), + }, + }, + }, + { + Name: proto.String("other.proto"), + Package: proto.String("otherpackage"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("OtherRequest"), + }, + { + Name: proto.String("OtherResponse"), + }, + }, + }, + }, + }, + wantErr: false, + wantStrings: []string{"class TestServiceASGIApplication"}, + dontWantStrings: []string{"class TestServiceWSGIApplication"}, + }, + { + name: "sync only", + req: &pluginpb.CodeGeneratorRequest{ + FileToGenerate: []string{"test.proto"}, + Parameter: proto.String("async=false"), + ProtoFile: []*descriptorpb.FileDescriptorProto{ + { + Name: proto.String("test.proto"), + Package: proto.String("test"), + Dependency: []string{"other.proto"}, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("TestService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("TestMethod"), + InputType: proto.String(".test.TestRequest"), + OutputType: proto.String(".test.TestResponse"), + }, + { + Name: proto.String("TestMethod2"), + InputType: proto.String(".otherpackage.OtherRequest"), + OutputType: proto.String(".otherpackage.OtherResponse"), + }, + // Reserved keyword + { + Name: proto.String("Try"), + InputType: proto.String(".otherpackage.OtherRequest"), + OutputType: proto.String(".otherpackage.OtherResponse"), + }, + }, + }, + }, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("TestRequest"), + }, + { + Name: proto.String("TestResponse"), + }, + }, + }, + { + Name: proto.String("other.proto"), + Package: proto.String("otherpackage"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("OtherRequest"), + }, + { + Name: proto.String("OtherResponse"), + }, + }, + }, + }, + }, + wantErr: false, + wantStrings: []string{"class TestServiceWSGIApplication"}, + dontWantStrings: []string{"class TestServiceASGIApplication"}, }, } @@ -219,6 +340,11 @@ func TestGenerate(t *testing.T) { t.Errorf("generate() missing expected string: %v", s) } } + for _, s := range tt.dontWantStrings { + if strings.Contains(resp.GetFile()[0].GetContent(), s) { + t.Errorf("generate() contains unexpected string: %v", s) + } + } } }) } diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 8a8d32e..2add65c 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -13,6 +13,8 @@ type ConnectTemplateVariables struct { ModuleName string Imports []ImportStatement Services []*ConnectService + SkipAsync bool + SkipSync bool } type ConnectService struct { @@ -61,9 +63,9 @@ from connectrpc.server import ConnectASGIApplication, ConnectWSGIApplication, En {{if .Relative}}from . import {{.Name}}{{else}}import {{.Name}}{{end}} as {{.Alias}} {{- end}} {{- end}} -{{- range .Services}} - +{{if not .SkipAsync }} +{{- range .Services}} class {{.Name}}(Protocol):{{- range .Methods }} {{if not .ResponseStream }}async {{end}}def {{.PythonName}}(self, request: {{if .RequestStream}}AsyncIterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext) -> {{if .ResponseStream}}AsyncIterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") @@ -124,6 +126,9 @@ class {{.Name}}Client(ConnectClient):{{range .Methods}} {{- end}} ) {{end}}{{- end }} +{{end}} + +{{if not .SkipSync }} {{range .Services}} class {{.Name}}Sync(Protocol):{{- range .Methods }} def {{.PythonName}}(self, request: {{if .RequestStream}}Iterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext) -> {{if .ResponseStream}}Iterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}: @@ -184,4 +189,6 @@ class {{.Name}}ClientSync(ConnectClientSync):{{range .Methods}} use_get=use_get, {{- end}} ) -{{end}}{{end}}`)) +{{end}}{{end}} +{{end}} +`))