Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions api/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package api
import (
"errors"
"fmt"
"io"
"net/http"

"github.com/Yiling-J/tablepilot/ent"
Expand All @@ -12,8 +11,6 @@ import (
"github.com/gin-gonic/gin"
)

// Dataset-related handlers

func (hs *HTTPServer) CreateDataset(ctx *gin.Context) {
var apiReq services_dataset.DatasetAPIRequest
if err := ctx.ShouldBind(&apiReq); err != nil {
Expand All @@ -28,19 +25,22 @@ func (hs *HTTPServer) CreateDataset(ctx *gin.Context) {
Data: apiReq.Data,
}

if apiReq.Type == "csv" {
if apiReq.Type == "csv" || apiReq.Type == "image" {
if len(apiReq.Files) == 0 {
errorResponse(ctx, http.StatusBadRequest, errors.New("at least one file is required for CSV dataset type"))
return
}
var readers []io.Reader
var readers []services_dataset.CreateDatasetFile
for _, fh := range apiReq.Files {
f, err := fh.Open()
if err != nil {
errorResponse(ctx, http.StatusBadRequest, err)
return
}
readers = append(readers, f)
readers = append(readers, services_dataset.CreateDatasetFile{
Name: fh.Filename,
Reader: f,
})
}
serviceReq.Files = readers
}
Expand Down Expand Up @@ -107,19 +107,22 @@ func (hs *HTTPServer) UpdateDataset(ctx *gin.Context) {
}

if apiReq.Files != nil {
var readers []io.Reader
var readers []services_dataset.CreateDatasetFile
for _, fh := range apiReq.Files {
f, err := fh.Open()
if err != nil {
errorResponse(ctx, http.StatusBadRequest, err)
return
}
readers = append(readers, f)
readers = append(readers, services_dataset.CreateDatasetFile{
Name: fh.Filename,
Reader: f,
})
}
serviceReq.Files = readers
serviceReq.Fields = append(serviceReq.Fields, "files")
} else {
serviceReq.Files = []io.Reader{}
serviceReq.Files = []services_dataset.CreateDatasetFile{}
}

err := hs.DatasetService.Update(ctx.Request.Context(), datasetID, serviceReq)
Expand Down
2 changes: 1 addition & 1 deletion api/dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestAPI_CreateDatasetWithFiles(t *testing.T) {
require.Equal(t, expectedRequest.Description, req.Description)
require.Equal(t, expectedRequest.Type, req.Type)
require.Equal(t, 1, len(req.Files))
data, err := io.ReadAll(req.Files[0])
data, err := io.ReadAll(req.Files[0].Reader)
require.NoError(t, err)
require.Equal(t, "header1,header2,header3\nr1c1,r1c2,r1c3\n", string(data))
return "new_dataset_id", nil
Expand Down
33 changes: 18 additions & 15 deletions cmd/cli/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,17 @@ func (h *Handler) CreateDataset(cmd *cobra.Command, args []string) error {
options = append(options, o...)
}
req.Data = options
case "csv":
case "csv", "image":
filePaths, err := cmd.Flags().GetStringArray("path")
if err != nil {
return fmt.Errorf("error getting file flag for csv type: %w", err)
}
if len(filePaths) == 0 {
return fmt.Errorf("at least one --path must be provided for type 'csv'")
return fmt.Errorf("at least one --path must be provided")
}
var readers []io.Reader
var readers []dataset.CreateDatasetFile
files, err := parsePaths(filePaths)
names := []string{}
if err != nil {
return err
}
Expand All @@ -199,12 +200,17 @@ func (h *Handler) CreateDataset(cmd *cobra.Command, args []string) error {
if err != nil {
return fmt.Errorf("failed to open file %s: %w", f, err)
}
readers = append(readers, file)
readers = append(readers, dataset.CreateDatasetFile{
Name: filepath.Base(f),
Reader: file,
})
names = append(names, filepath.Base(f))
}
req.Files = readers
req.Data = names
defer func() {
for _, f := range readers {
if c, ok := f.(io.Closer); ok {
if c, ok := f.Reader.(io.Closer); ok {
c.Close()
}
}
Expand Down Expand Up @@ -307,32 +313,29 @@ func (h *Handler) UpdateDataset(cmd *cobra.Command, args []string) error {
options = append(options, o...)
}
req.Data = options
case "csv":
case "csv", "image":
filePaths, err := cmd.Flags().GetStringArray("file")
if err != nil {
return fmt.Errorf("error getting file flag for csv type: %w", err)
}
if len(filePaths) == 0 {
return fmt.Errorf("at least one --file path must be provided for type 'csv'")
}
var files []io.Reader
var files []dataset.CreateDatasetFile
for _, filePath := range filePaths {
file, err := os.Open(filePath)
if err != nil {
// Close already opened files if any
for _, f := range files {
if c, ok := f.(io.Closer); ok {
c.Close()
}
}
return fmt.Errorf("failed to open file %s: %w", filePath, err)
}
files = append(files, file)
files = append(files, dataset.CreateDatasetFile{
Name: filepath.Base(filePath),
Reader: file,
})
}
req.Files = files
defer func() {
for _, f := range files {
if c, ok := f.(io.Closer); ok {
if c, ok := f.Reader.(io.Closer); ok {
c.Close()
}
}
Expand Down
7 changes: 4 additions & 3 deletions ent/dataset/dataset.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion ent/migrate/schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion ent/schema/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (Dataset) Fields() []ent.Field {
field.String("name").Unique().NotEmpty(),
field.String("path").Optional(),
field.String("description").Default(""),
field.Enum("type").Values("list", "csv"),
field.Enum("type").Values("list", "csv", "image"),
field.JSON("indexer", CSVIndexer{}).Optional(),
field.Strings("values").Optional(),
}
Expand Down
104 changes: 58 additions & 46 deletions services/dataset/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (
"io"
"os"
"path/filepath"
"slices"

"github.com/Yiling-J/tablepilot/config"
"github.com/Yiling-J/tablepilot/ent"
db_dataset "github.com/Yiling-J/tablepilot/ent/dataset"
"github.com/Yiling-J/tablepilot/services/source"
"github.com/Yiling-J/tablepilot/services/source/csvindexer"
"github.com/Yiling-J/tablepilot/utils"
)

//go:generate moq -rm -out dataset_moq.go . DatasetService
Expand All @@ -39,56 +39,43 @@ func NewDatasetService(db *ent.Client, cfg *config.Config) *DatasetServiceImpl {
}

func (s DatasetServiceImpl) buildCreateDatasetReq(ctx context.Context, req *CreateDatasetRequest, sr *ent.Dataset) error {
switch req.Type {
case db_dataset.TypeCsv:
relativePath := filepath.Join("datasets/shared", sr.Nanoid)
dirPath := filepath.Join(s.cfg.Common.DataDir, relativePath)
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
relativePath := filepath.Join("datasets/shared", sr.Nanoid)
dirPath := filepath.Join(s.cfg.Common.DataDir, relativePath)
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}

if len(req.Files) == 0 {
return errors.New("dataset.Create: files should not be empty")
}
filePath := filepath.Join(dirPath, "data.csv")
outFile, err := os.Create(filePath)
for _, file := range req.Files {
outFile, err := os.Create(filepath.Join(dirPath, file.Name))
if err != nil {
return fmt.Errorf("failed to create file %s: %w", filePath, err)
return fmt.Errorf("failed to create file %w", err)
}
for i, file := range req.Files {
// skip csv headers
if i > 0 {
reader := utils.NewCsvReader(file)
_, err = reader.Read()
if err != nil {
return fmt.Errorf("failed to read csv %w", err)
}
offset := reader.InputOffset()
_, err = file.(io.ReadSeeker).Seek(offset, io.SeekStart)
if err != nil {
return fmt.Errorf("failed to seek csv file %w", err)
}
}
_, err = io.Copy(outFile, file)
if err != nil {
return fmt.Errorf("failed to write to file %s: %w", filePath, err)
}
defer outFile.Close()
_, err = io.Copy(outFile, file.Reader)
if err != nil {
return fmt.Errorf("failed to write to file %w", err)
}
outFile.Close()
// build index
indexer, err := csvindexer.NewCSVIndexer(os.DirFS(dirPath), []string{"data.csv"})
}
switch req.Type {
case db_dataset.TypeCsv:
indexer, err := csvindexer.NewCSVIndexer(os.DirFS(dirPath), req.Data)
if err != nil {
return fmt.Errorf("table.Create: build csv index: %w", err)
}
err = sr.Update().SetPath(relativePath).SetIndexer(indexer.CSVIndexer).Exec(ctx)
err = sr.Update().SetPath(relativePath).SetIndexer(indexer.CSVIndexer).SetValues(req.Data).Exec(ctx)
if err != nil {
return fmt.Errorf("table.Create: update dataset metadata: %w", err) // Clarified error
return fmt.Errorf("table.Create: update dataset metadata: %w", err)
}
case db_dataset.TypeImage:
err = sr.Update().SetPath(relativePath).SetValues(req.Data).Exec(ctx)
if err != nil {
return fmt.Errorf("table.Create: update dataset metadata: %w", err)
}
case db_dataset.TypeList:
err := sr.Update().SetValues(req.Data).Exec(ctx)
if err != nil {
return fmt.Errorf("table.Create: update dataset values: %w", err) // Clarified error
return fmt.Errorf("table.Create: update dataset values: %w", err)
}
}
return nil
Expand Down Expand Up @@ -119,6 +106,10 @@ func (s *DatasetServiceImpl) List(ctx context.Context) ([]*DatasetInfo, error) {
}
datasetInfos := []*DatasetInfo{}
for _, ds := range datasets {
// backward compatible
if ds.Type == db_dataset.TypeCsv && len(ds.Values) == 0 {
ds.Values = []string{"data.csv"}
}
datasetInfos = append(datasetInfos, &DatasetInfo{
ID: ds.Nanoid,
Name: ds.Name,
Expand Down Expand Up @@ -161,22 +152,38 @@ func (s *DatasetServiceImpl) Update(ctx context.Context, dataset string, req *Up
case "description":
updater.SetDescription(req.Description)
case "data", "files":
processDataRebuild = true
updater.ClearIndexer().ClearPath().SetValues(nil)
// new files or data slice change
if len(req.Files) > 0 || !slices.Equal(req.Data, ds.Values) {
processDataRebuild = true
updater.ClearIndexer().ClearPath().SetValues(nil)
}
}
}

updatedDsEntity, err := updater.Save(ctx)
if err != nil {
return ent.Rollback(tx, fmt.Errorf("dataset.Update: save changes: %w", err)) // Clarified error
return ent.Rollback(tx, fmt.Errorf("dataset.Update: save changes: %w", err))
}

if processDataRebuild {
if originalPath != "" {
oldDirPath := filepath.Join(s.cfg.Common.DataDir, originalPath)
if _, statErr := os.Stat(oldDirPath); !os.IsNotExist(statErr) {
if removeErr := os.RemoveAll(oldDirPath); removeErr != nil {
return ent.Rollback(tx, fmt.Errorf("dataset.Update: failed to remove old directory %s: %w", oldDirPath, removeErr))
keep := map[string]bool{}
for _, file := range req.Data {
keep[file] = true
}
entries, err := os.ReadDir(oldDirPath)
if err != nil {
return ent.Rollback(tx, fmt.Errorf("dataset.Update: read dir: %w", err))
}
for _, e := range entries {
if _, ok := keep[e.Name()]; !ok {
err = os.Remove(filepath.Join(oldDirPath, e.Name()))
if err != nil {
return ent.Rollback(tx, fmt.Errorf("dataset.Update: remove file: %w", err))
}
}
}
}
}
Expand Down Expand Up @@ -217,7 +224,7 @@ func (s *DatasetServiceImpl) Delete(ctx context.Context, dataset string) error {

err = tx.Dataset.DeleteOne(ds).Exec(ctx)
if err != nil {
return ent.Rollback(tx, fmt.Errorf("dataset.Delete: execute delete: %w", err)) // Clarified error
return ent.Rollback(tx, fmt.Errorf("dataset.Delete: execute delete: %w", err))
}

return tx.Commit()
Expand All @@ -229,7 +236,7 @@ func (s *DatasetServiceImpl) Get(ctx context.Context, source string) (*DatasetIn
db_dataset.Nanoid(source),
)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("dataset.Get: query dataset: %w", err) // Clarified error
return nil, fmt.Errorf("dataset.Get: query dataset: %w", err)
}
return &DatasetInfo{
Name: sr.Name,
Expand Down Expand Up @@ -276,6 +283,11 @@ func (s *DatasetServiceImpl) Preview(ctx context.Context, dataset string) (*Data
Type: sr.Type,
Rows: rows,
}, nil
case db_dataset.TypeImage:
return &DatasetRows{
Type: sr.Type,
Data: sr.Values,
}, nil
case db_dataset.TypeList:
return &DatasetRows{
Type: sr.Type,
Expand Down
Loading