Skip to content

opt:(write) support format codes #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.23.0

require (
github.com/Knetic/govaluate v3.0.0+incompatible
github.com/davecgh/go-spew v1.1.1
github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd
github.com/sourcegraph/jsonrpc2 v0.2.0
github.com/stretchr/testify v1.10.0
Expand All @@ -13,6 +12,7 @@ require (
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sync v0.13.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
9 changes: 5 additions & 4 deletions lang/golang/parser/pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"go/types"
"os"
"path/filepath"
"strconv"
"strings"

. "github.com/cloudwego/abcoder/lang/uniast"
Expand All @@ -33,15 +34,15 @@ func (p *GoParser) parseImports(fset *token.FileSet, file []byte, mod *Module, i
sysImports := make(map[string]string)
ret := &importInfo{}
for _, imp := range impts {
importPath := imp.Path.Value[1 : len(imp.Path.Value)-1] // remove the quotes
importPath, _ := strconv.Unquote(imp.Path.Value) // remove the quotes
importAlias := ""
// Check if user has defined an alias for current import
if imp.Name != nil {
importAlias = imp.Name.Name // update the alias
ret.Origins = append(ret.Origins, Import{Path: imp.Path.Value, Alias: &importAlias})
ret.Origins = append(ret.Origins, Import{Path: importPath, Alias: &importAlias})
} else {
importAlias = getPackageAlias(importPath)
ret.Origins = append(ret.Origins, Import{Path: imp.Path.Value})
ret.Origins = append(ret.Origins, Import{Path: importPath})
}

// Fix: module name may also be like this?
Expand Down Expand Up @@ -212,7 +213,7 @@ func (p *GoParser) loadPackages(mod *Module, dir string, pkgPath PkgPath) (err e
mod.Files[relpath] = f
}
pkgid := pkg.ID
f.Package = &pkgid
f.Package = []PkgPath{pkgid}
f.Imports = imports.Origins
if err := p.parseFile(ctx, file); err != nil {
return err
Expand Down
3 changes: 2 additions & 1 deletion lang/golang/writer/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package writer

import (
"strconv"
"strings"

"github.com/cloudwego/abcoder/lang/uniast"
Expand Down Expand Up @@ -44,7 +45,7 @@ func writeSingleImport(sb *strings.Builder, v uniast.Import) {
sb.WriteString(*v.Alias)
sb.WriteString(" ")
}
sb.WriteString(v.Path)
sb.WriteString(strconv.Quote(v.Path))
sb.WriteString("\n")
}

Expand Down
63 changes: 47 additions & 16 deletions lang/golang/writer/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package writer

import (
"bytes"
"context"
"fmt"
"go/ast"
"go/parser"
Expand All @@ -29,7 +31,6 @@ import (
"strconv"
"strings"

"github.com/cloudwego/abcoder/lang/log"
"github.com/cloudwego/abcoder/lang/uniast"
"github.com/cloudwego/abcoder/lang/utils"
)
Expand Down Expand Up @@ -180,12 +181,6 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir str
return fmt.Errorf("write go.mod failed: %v", err)
}

// go mod tidy
cmd := exec.Command(w.Options.CompilerPath, "mod", "tidy")
cmd.Dir = outdir
if err := cmd.Run(); err != nil {
log.Error("go mod tidy failed: %v", err)
}
return nil
}

Expand Down Expand Up @@ -263,7 +258,7 @@ func (w *Writer) appendNode(node *uniast.Node, pkg string, isMain bool, file str
if v.PkgPath == "" || v.PkgPath == pkg {
continue
}
fs.impts = append(fs.impts, uniast.Import{Path: strconv.Quote(v.PkgPath)})
fs.impts = append(fs.impts, uniast.Import{Path: v.PkgPath})
}

// 检查是否有imports
Expand All @@ -283,18 +278,24 @@ func (w *Writer) appendNode(node *uniast.Node, pkg string, isMain bool, file str

// receive a piece of golang code, parse it and splits the imports and codes
func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []uniast.Import, err error) {
var src2 = src
if !strings.Contains("package ", src) {
src2 = "package main\n\n" + src
}
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", src, parser.SkipObjectResolution)
f, err := parser.ParseFile(fset, "", src2, parser.SkipObjectResolution)
if err != nil {
// NOTICE: if parse failed, just return the src
return src, nil, nil
}
for _, imp := range f.Imports {
var alias string
s, _ := strconv.Unquote(imp.Path.Value)
v := uniast.Import{Path: s}
if imp.Name != nil {
alias = imp.Name.Name
tmp := imp.Name.Name
v.Alias = &tmp
}
imports = append(imports, uniast.Import{Path: imp.Path.Value, Alias: &alias})
imports = append(imports, v)
}
start := 0
for _, s := range f.Decls {
Expand All @@ -304,11 +305,11 @@ func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []uniast
start = fset.Position(s.Pos()).Offset
break
}
return src[start:], imports, nil
return src2[start:], imports, nil
}

func (w *Writer) IdToImport(id uniast.Identity) (uniast.Import, error) {
return uniast.Import{Path: strconv.Quote(id.PkgPath)}, nil
return uniast.Import{Path: id.PkgPath}, nil
}

func (p *Writer) PatchImports(impts []uniast.Import, file []byte) ([]byte, error) {
Expand All @@ -321,8 +322,9 @@ func (p *Writer) PatchImports(impts []uniast.Import, file []byte) ([]byte, error

old := make([]uniast.Import, 0, len(f.Imports))
for _, imp := range f.Imports {
v, _ := strconv.Unquote(imp.Path.Value)
i := uniast.Import{
Path: imp.Path.Value,
Path: v,
}
if imp.Name != nil {
tmp := imp.Name.Name
Expand Down Expand Up @@ -364,7 +366,7 @@ func (p *Writer) CreateFile(fi *uniast.File, mod *uniast.Module) ([]byte, error)
sb.WriteString("package ")
pkgName := filepath.Base(filepath.Dir(fi.Path))
if fi.Package != nil {
pkg := mod.Packages[*fi.Package]
pkg := mod.Packages[fi.Package[0]]
if pkg != nil {
if pkg.IsMain {
pkgName = "main"
Expand All @@ -386,3 +388,32 @@ func (p *Writer) CreateFile(fi *uniast.File, mod *uniast.Module) ([]byte, error)
bs := sb.String()
return []byte(bs), nil
}

func (p *Writer) Format(ctx context.Context, path string) error {
fi, err := os.Stat(path)
if err != nil {
return fmt.Errorf("stat %s failed: %v", path, err)
}

// call goimports
if err := utils.ExecCmdWithInstall(ctx, "goimports", []string{"-w", path}, p.CompilerPath, []string{"install", "golang.org/x/tools/cmd/goimports@latest"}); err != nil {
return fmt.Errorf("goimports failed: %v", err)
}
// call gofmt
if err := utils.ExecCmdWithInstall(ctx, "gofmt", []string{"-w", path}, p.CompilerPath, []string{"install", "golang.org/x/tools/cmd/gofmt@latest"}); err != nil {
return fmt.Errorf("gofmt failed: %v", err)
}
// call go mod tidy
cmd := exec.CommandContext(ctx, p.CompilerPath, "mod", "tidy")
cmd.Dir = path
if !fi.IsDir() {
cmd.Dir = filepath.Dir(path)
}
buf := bytes.NewBuffer(nil)
cmd.Stderr = buf
cmd.Stdout = buf
if err := cmd.Run(); err != nil {
return fmt.Errorf("go mod tidy failed: %v\n%s", err, buf.String())
}
return nil
}
2 changes: 1 addition & 1 deletion lang/golang/writer/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ import "fmt"
file: &uniast.File{
Imports: []uniast.Import{
{
Path: `"runtime"`,
Path: `runtime`,
Alias: &alias1,
},
},
Expand Down
25 changes: 17 additions & 8 deletions lang/patch/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package patch

import (
"context"
"fmt"
"math"
"os"
Expand Down Expand Up @@ -124,7 +125,7 @@ next_dep:
for _, dep := range patch.AddedDeps {
impt, err := w.IdToImport(dep)
if err != nil {
return fmt.Errorf("convert identity %s to import failed: %v", dep.Full(), err)
return utils.WrapError(err, "convert identity %s to import failed", dep.Full())
}
f.Imports = uniast.InserImport(f.Imports, impt)
}
Expand All @@ -135,7 +136,7 @@ next_dep:
File: f,
}
if err := p.patch(n); err != nil {
return fmt.Errorf("patch file %s failed: %v", f.Path, err)
return utils.WrapError(err, "patch file %s failed", f.Path)
}
return nil
}
Expand Down Expand Up @@ -168,7 +169,7 @@ func (p *Patcher) Flush() error {
fi := mod.GetFile(fpath)
data, err = writer.CreateFile(fi, mod)
if err != nil {
return fmt.Errorf("create file %s failed: %v", fpath, err)
return utils.WrapError(err, "create file %s failed", fpath)
}
}

Expand All @@ -189,7 +190,7 @@ func (p *Patcher) Flush() error {
}

if err := utils.MustWriteFile(filepath.Join(p.OutDir, fpath), data); err != nil {
return fmt.Errorf("write file %s failed: %v", fpath, err)
return utils.WrapError(err, "write file %s failed", fpath)
}

// patch imports
Expand All @@ -199,12 +200,13 @@ func (p *Patcher) Flush() error {
if mod == nil {
return fmt.Errorf("module %s not found", n.Identity.ModPath)
}
n.File.RemoveUnusedImports(p.repo)
data, err := writer.PatchImports(n.File.Imports, data)
if err != nil {
return fmt.Errorf("patch imports failed: %v", err)
return utils.WrapError(err, "patch imports failed")
}
if err := utils.MustWriteFile(filepath.Join(p.OutDir, fpath), data); err != nil {
return fmt.Errorf("write file %s failed: %v", fpath, err)
return utils.WrapError(err, "write file %s failed: %v", fpath)
}
}
}
Expand All @@ -218,13 +220,20 @@ func (p *Patcher) Flush() error {
fpath := filepath.Join(p.RepoDir, f.Path)
bs, err := os.ReadFile(fpath)
if err != nil {
return fmt.Errorf("read file %s failed: %v", fpath, err)
return utils.WrapError(err, "read file %s failed", fpath)
}
fpath = filepath.Join(p.OutDir, f.Path)
if err := utils.MustWriteFile(fpath, bs); err != nil {
return fmt.Errorf("write file %s failed: %v", fpath, err)
return utils.WrapError(err, "write file %s failed", fpath)
}
}
w := p.getLangWriter(mod.Language)
if w == nil {
return fmt.Errorf("unsupported language %s writer", mod.Language)
}
if err := w.Format(context.Background(), p.OutDir); err != nil {
return utils.WrapError(err, "format file %s failed", p.OutDir)
}
}
return nil
}
Expand Down
29 changes: 25 additions & 4 deletions lang/uniast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,37 @@ func NewRepository(name string) Repository {

type File struct {
Path string
Imports []Import `json:",omitempty"`
Package *PkgPath `json:",omitempty"`
Imports []Import `json:",omitempty"`
Package []PkgPath `json:",omitempty"` // related packages, maybe one (belong to) or many (children)
Nodes []Identity `json:",omitempty"`
}

func (f *File) RemoveUnusedImports(repo *Repository) {
marked := make(map[string]bool, len(f.Imports))
for _, id := range f.Nodes {
node := repo.GetNode(id)
if node == nil {
continue
}
for _, dep := range node.Dependencies {
marked[dep.Identity.PkgPath] = true
}
}
final := make([]Import, 0, len(f.Imports))
for i := len(f.Imports) - 1; i >= 0; i-- {
if marked[f.Imports[i].Path] {
final = InserImport(final, f.Imports[i])
}
}
f.Imports = final
}

type Import struct {
Alias *string `json:",omitempty"`
Path string
Path PkgPath
}

func NewImport(alias *string, path string) Import {
func NewImport(alias *string, path PkgPath) Import {
return Import{
Alias: alias,
Path: path,
Expand Down
Loading