diff --git a/writer.go b/writer.go index 231cfce..16a7ea6 100644 --- a/writer.go +++ b/writer.go @@ -5,7 +5,6 @@ import ( "log" "os" "path" - "sort" "strings" "github.com/CalebQ42/squashfs/internal/inode" @@ -17,10 +16,10 @@ type Writer struct { files map[string][]*File directories []string symlinkTable map[string]string //symlinkTable holds info about symlink'd to files that had to be moved from their original position. [originalpath]newpath + symTableTemp map[string]string ResolveSymlinks bool AllowErrors bool compression int - temp []*File } //NewWriter creates a new squashfs.Writer with the default settings (gzip compression, autoresolving symlinks, and allowErrors) @@ -51,87 +50,102 @@ func NewWriterWithOptions(resolveSymlinks, allowErrors bool, compressionType int return &out, nil } -//convertFile converts the given os.File to a squashfs.File and then adds it to the Writer's temp File slice. -func (w *Writer) convertFile(squashfsPath string, file *os.File, errChan chan error) { +type fileError struct { + files []*File + err error +} + +//convertFile converts the given os.File to a squashfs.File. Returns the errors and converted file to the channels. +func (w *Writer) convertFile(squashfsPath string, file *os.File, subDir bool, fileErrChan chan fileError) { + var out fileError var fil File fil.Reader = file - fil.name = path.Base(file.Name()) fil.path = squashfsPath - stat, err := file.Stat() - if err != nil { - if w.AllowErrors { - log.Println("Error while getting FileInfo for", file.Name()+":") - log.Println(err) - err = nil - } - errChan <- err - return - } - if stat.IsDir() { - fil.filType = inode.BasicDirectoryType - dirs, err := file.Readdirnames(-1) + fil.name = path.Base(file.Name()) + mode := fil.Mode() + + if mode.IsRegular() { + fil.filType = inode.BasicFileType + goto successExit + } else if mode.IsDir() { + fil.filType = inode.BasicSymlinkType + subDirs, err := file.Readdirnames(-1) if err != nil { - if w.AllowErrors { - log.Println("Error when getting directory names for", file.Name()+":") + if w.AllowErrors && !subDir { + log.Println("Can't get sub-directories for", file.Name()) log.Println(err) - err = nil + } else { + out.err = err } - errChan <- err - return + goto failExit } - subDirErrChan := make(chan error) - for _, dir := range dirs { - go func(newFilename string, errChan chan error) { - subFil, err := os.Open(file.Name() + newFilename) + subDirChan := make(chan fileError) + for _, filName := range subDirs { + go func(filename string, returnChan chan fileError) { + subFil, err := os.Open(filename) if err != nil { - if w.AllowErrors { - log.Println("Error when opening sub-directory", subFil.Name()+":") - log.Println(err) - err = nil + out.err = err + returnChan <- fileError{ + err: err, } - errChan <- err return } - subDirErrChan := make(chan error) - w.convertFile(fil.Path(), subFil, subDirErrChan) - errChan <- <-subDirErrChan - return - }(dir, subDirErrChan) + w.convertFile(fil.Path(), subFil, true, subDirChan) + }(file.Name()+filName, subDirChan) } - for range dirs { - err = <-subDirErrChan - if err != nil { - errChan <- err - return + for range subDirs { + filErr := <-subDirChan + if filErr.err != nil { + if w.AllowErrors && !subDir { + log.Println("Error while adding subdirectory of", file.Name()) + log.Println(filErr.err) + } else if subDir { + if out.err == nil { + out.err = filErr.err + } + } else { + out.err = err + goto failExit + } + continue } + out.files = append(out.files, filErr.files...) } - w.temp = append(w.temp, &fil) - errChan <- nil - return - } else if stat.Mode().IsRegular() { - fil.filType = inode.BasicFileType - w.temp = append(w.temp, &fil) - errChan <- nil - return - } else if stat.Mode()&os.ModeSymlink == os.ModeSymlink { - linkLocation, err := os.Readlink(file.Name()) + goto successExit + } else if mode&os.ModeSymlink == os.ModeSymlink { + fil.filType = inode.BasicSymlinkType + symLocation, err := os.Readlink(file.Name()) if err != nil { - if w.AllowErrors { - log.Println("Error when reading symlink's target", file.Name()+":") + if w.AllowErrors && !subDir { + log.Println("Error while getting symlink's information for", file.Name()) log.Println(err) - err = nil + } else { + out.err = err } - errChan <- err - return + goto failExit } if w.ResolveSymlinks { - if w.symlinkTable[linkLocation] != "" { - linkLocation = w.symlinkTable[linkLocation] + if val, ok := w.symlinkTable[symLocation]; ok { + symLocation = val + } else if val, ok := w.symTableTemp[symLocation]; ok { + symLocation = val + } else { + //TODO: either add the file, or place the file in this location. Maybe defer this until after all the other files are added? } } - //TODO: finish symlink support + //TODO: store the symLocation inside the File somehow.... } - errChan <- errors.New("Unsupported file type") + if w.AllowErrors && !subDir { + log.Println("Unsupported file type for", file.Name()) + } else { + out.err = errors.New("Unsupported file type") + } +failExit: //before this is used, make sure to log or set the error. + fileErrChan <- out + return +successExit: + out.files = []*File{&fil} + fileErrChan <- out return } @@ -145,33 +159,10 @@ func (w *Writer) AddFilesToPath(squashfsPath string, files ...*os.File) error { if squashfsPath == "." { squashfsPath = "/" } - errChan := make(chan error) + fileErrChan := make(chan fileError) for _, fil := range files { - go w.convertFile(squashfsPath, fil, errChan) + go w.convertFile(squashfsPath, fil, false, fileErrChan) } - var firstError error - for range files { - err := <-errChan - if firstError != nil && err != nil { - firstError = err - } - } - if firstError != nil { - w.temp = nil - return firstError - } - for _, tempFil := range w.temp { - if tempFil.path != "/" { - ind := sort.SearchStrings(w.directories, tempFil.path) - if ind == len(w.directories) { - w.directories = append(w.directories, tempFil.path) - sort.Strings(w.directories) - } - } - w.files[tempFil.path] = append(w.files[tempFil.path], tempFil) - } - w.temp = nil - return nil } //AddFiles adds all files given to the root directory