media more or less working

This commit is contained in:
tsmethurst
2021-03-30 16:06:08 +02:00
parent 572149fa0e
commit 362ccf5817
12 changed files with 112 additions and 363 deletions

View File

@ -155,6 +155,9 @@ type DB interface {
// By the time this function is called, it should be assumed that all the parameters have passed validation! // By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error)
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
SetHeaderOrAvatarForAccountID(mediaAttachmen *model.MediaAttachment, accountID string) error
/* /*
USEFUL CONVERSION FUNCTIONS USEFUL CONVERSION FUNCTIONS
*/ */

View File

@ -50,9 +50,9 @@ type MediaAttachment struct {
// What is the processing status of this attachment // What is the processing status of this attachment
Processing ProcessingStatus Processing ProcessingStatus
// metadata for the whole file // metadata for the whole file
File File File
// small image thumbnail derived from a larger image, video, or audio file. // small image thumbnail derived from a larger image, video, or audio file.
Thumbnail Thumbnail Thumbnail
// Is this attachment being used as an avatar? // Is this attachment being used as an avatar?
Avatar bool Avatar bool
// Is this attachment being used as a header? // Is this attachment being used as a header?
@ -68,7 +68,7 @@ type File struct {
// What is the size of the file in bytes. // What is the size of the file in bytes.
FileSize int FileSize int
// When was the file last updated. // When was the file last updated.
UpdatedAt time.Time `pg:"type:timestamp,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
} }
// Thumbnail refers to a small image thumbnail derived from a larger image, video, or audio file. // Thumbnail refers to a small image thumbnail derived from a larger image, video, or audio file.
@ -80,7 +80,7 @@ type Thumbnail struct {
// What is the size of the file in bytes // What is the size of the file in bytes
FileSize int FileSize int
// When was the file last updated // When was the file last updated
UpdatedAt time.Time `pg:"type:timestamp,default:now()"` UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// What is the remote URL of the thumbnail // What is the remote URL of the thumbnail
RemoteURL string RemoteURL string
} }
@ -113,50 +113,12 @@ const (
FileTypeVideo FileType = "video" FileTypeVideo FileType = "video"
) )
/*
FILEMETA INTERFACES
*/
// FileMeta describes metadata about the actual contents of the file. // FileMeta describes metadata about the actual contents of the file.
type FileMeta interface { type FileMeta struct {
GetOriginal() OriginalMeta Original Original
GetSmall() SmallMeta Small Small
} }
// OriginalMeta contains info about the originally submitted media
type OriginalMeta interface {
// GetWidth gets the width of a video or image or gif in pixels.
GetWidth() int
// GetHeight gets the height of a video or image or gif in pixels.
GetHeight() int
// GetSize gets the total area of a video or image or gif in pixels (width * height).
GetSize() int
// GetAspect gets the aspect ratio of a video or image or gif in pixels (width / height).
GetAspect() float64
// GetFrameRate gets the FPS of a video or gif.
GetFrameRate() float64
// GetDuration gets the length in seconds of a video or gif or audio file.
GetDuration() float64
// GetBitrate gets the bits per second of a video, gif, or audio file.
GetBitrate() float64
}
// SmallMeta contains info about the derived thumbnail for the submitted media
type SmallMeta interface {
// GetWidth gets the width of a video or image or gif in pixels.
GetWidth() int
// GetHeight gets the height of a video or image or gif in pixels.
GetHeight() int
// GetSize gets the total area of a video or image or gif in pixels (width * height).
GetSize() int
// GetAspect gets the aspect ratio of a video or image or gif in pixels (width / height).
GetAspect() float64
}
/*
FILE META IMPLEMENTATIONS
*/
// Small implements SmallMeta and can be used for a thumbnail of any media type // Small implements SmallMeta and can be used for a thumbnail of any media type
type Small struct { type Small struct {
Width int Width int
@ -165,70 +127,10 @@ type Small struct {
Aspect float64 Aspect float64
} }
func (s Small) GetWidth() int {
return s.Width
}
func (s Small) GetHeight() int {
return s.Height
}
func (s Small) GetSize() int {
return s.Height * s.Width
}
func (s Small) GetAspect() float64 {
return float64(s.Width) / float64(s.Height)
}
// STILL IMAGES
// ImageFileMeta implements FileMeta for still images.
type ImageFileMeta struct {
Original ImageOriginal
Small Small
}
func (m ImageFileMeta) GetOriginal() OriginalMeta {
return m.Original
}
func (m ImageFileMeta) GetSmall() SmallMeta {
return m.Small
}
// ImageOriginal implements OriginalMeta for still images // ImageOriginal implements OriginalMeta for still images
type ImageOriginal struct { type Original struct {
Width int Width int
Height int Height int
Size int Size int
Aspect float64 Aspect float64
} }
func (o ImageOriginal) GetWidth() int {
return o.Width
}
func (o ImageOriginal) GetHeight() int {
return o.Height
}
func (o ImageOriginal) GetSize() int {
return o.Height * o.Width
}
func (o ImageOriginal) GetAspect() float64 {
return float64(o.Width) / float64(o.Height)
}
func (o ImageOriginal) GetFrameRate() float64 {
return 0
}
func (o ImageOriginal) GetDuration() float64 {
return 0
}
func (o ImageOriginal) GetBitrate() float64 {
return 0
}

View File

@ -1,42 +0,0 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package model
import mock "github.com/stretchr/testify/mock"
// MockFileMeta is an autogenerated mock type for the FileMeta type
type MockFileMeta struct {
mock.Mock
}
// GetOriginal provides a mock function with given fields:
func (_m *MockFileMeta) GetOriginal() OriginalMeta {
ret := _m.Called()
var r0 OriginalMeta
if rf, ok := ret.Get(0).(func() OriginalMeta); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(OriginalMeta)
}
}
return r0
}
// GetSmall provides a mock function with given fields:
func (_m *MockFileMeta) GetSmall() SmallMeta {
ret := _m.Called()
var r0 SmallMeta
if rf, ok := ret.Get(0).(func() SmallMeta); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(SmallMeta)
}
}
return r0
}

View File

@ -1,108 +0,0 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package model
import mock "github.com/stretchr/testify/mock"
// MockOriginalMeta is an autogenerated mock type for the OriginalMeta type
type MockOriginalMeta struct {
mock.Mock
}
// GetAspect provides a mock function with given fields:
func (_m *MockOriginalMeta) GetAspect() float64 {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
return r0
}
// GetBitrate provides a mock function with given fields:
func (_m *MockOriginalMeta) GetBitrate() float64 {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
return r0
}
// GetDuration provides a mock function with given fields:
func (_m *MockOriginalMeta) GetDuration() float64 {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
return r0
}
// GetFrameRate provides a mock function with given fields:
func (_m *MockOriginalMeta) GetFrameRate() float64 {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
return r0
}
// GetHeight provides a mock function with given fields:
func (_m *MockOriginalMeta) GetHeight() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// GetSize provides a mock function with given fields:
func (_m *MockOriginalMeta) GetSize() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// GetWidth provides a mock function with given fields:
func (_m *MockOriginalMeta) GetWidth() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}

View File

@ -1,66 +0,0 @@
// Code generated by mockery v2.7.4. DO NOT EDIT.
package model
import mock "github.com/stretchr/testify/mock"
// MockSmallMeta is an autogenerated mock type for the SmallMeta type
type MockSmallMeta struct {
mock.Mock
}
// GetAspect provides a mock function with given fields:
func (_m *MockSmallMeta) GetAspect() float64 {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
return r0
}
// GetHeight provides a mock function with given fields:
func (_m *MockSmallMeta) GetHeight() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// GetSize provides a mock function with given fields:
func (_m *MockSmallMeta) GetSize() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// GetWidth provides a mock function with given fields:
func (_m *MockSmallMeta) GetWidth() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}

View File

@ -463,6 +463,11 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
return u, nil return u, nil
} }
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error {
_, err := ps.conn.Model(mediaAttachment).Insert()
return err
}
/* /*
CONVERSION FUNCTIONS CONVERSION FUNCTIONS
*/ */

View File

@ -21,7 +21,8 @@ package media
import ( import (
"errors" "errors"
"fmt" "fmt"
"io" "strings"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -36,7 +37,7 @@ type MediaHandler interface {
// SetHeaderOrAvatarForAccountID takes a new header image for an account, checks it out, removes exif data from it, // SetHeaderOrAvatarForAccountID takes a new header image for an account, checks it out, removes exif data from it,
// puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image, // puts it in whatever storage backend we're using, sets the relevant fields in the database for the new image,
// and then returns information to the caller about the new header. // and then returns information to the caller about the new header.
SetHeaderOrAvatarForAccountID(f io.Reader, accountID string, headerOrAvi string) (*model.MediaAttachment, error) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error)
} }
type mediaHandler struct { type mediaHandler struct {
@ -67,7 +68,7 @@ type HeaderInfo struct {
INTERFACE FUNCTIONS INTERFACE FUNCTIONS
*/ */
func (mh *mediaHandler) SetHeaderOrAvatarForAccountID(f io.Reader, accountID string, headerOrAvi string) (*model.MediaAttachment, error) { func (mh *mediaHandler) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error) {
l := mh.log.WithField("func", "SetHeaderForAccountID") l := mh.log.WithField("func", "SetHeaderForAccountID")
if headerOrAvi != "header" && headerOrAvi != "avatar" { if headerOrAvi != "header" && headerOrAvi != "avatar" {
@ -75,7 +76,7 @@ func (mh *mediaHandler) SetHeaderOrAvatarForAccountID(f io.Reader, accountID str
} }
// make sure we have an image we can handle // make sure we have an image we can handle
contentType, err := parseContentType(f) contentType, err := parseContentType(img)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,21 +84,23 @@ func (mh *mediaHandler) SetHeaderOrAvatarForAccountID(f io.Reader, accountID str
return nil, fmt.Errorf("%s is not an accepted image type", contentType) return nil, fmt.Errorf("%s is not an accepted image type", contentType)
} }
// extract the bytes if len(img) == 0 {
imageBytes := []byte{} return nil, fmt.Errorf("passed reader was of size 0")
size, err := f.Read(imageBytes)
if err != nil {
return nil, fmt.Errorf("error reading file bytes: %s", err)
} }
l.Tracef("read %d bytes of file", size) l.Tracef("read %d bytes of file", len(img))
// // close the open file--we don't need it anymore now we have the bytes
// if err := f.Close(); err != nil {
// return nil, fmt.Errorf("error closing file: %s", err)
// }
// process it // process it
return mh.processHeaderOrAvi(imageBytes, contentType, headerOrAvi, accountID) ma, err := mh.processHeaderOrAvi(img, contentType, headerOrAvi, accountID)
if err != nil {
return nil, fmt.Errorf("error processing %s: %s", headerOrAvi, err)
}
// set it in the database
if err := mh.db.SetHeaderOrAvatarForAccountID(ma, accountID); err != nil {
return nil, fmt.Errorf("error putting %s in database: %s", headerOrAvi, err)
}
return ma, nil
} }
/* /*
@ -131,6 +134,8 @@ func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string
} }
case "image/gif": case "image/gif":
clean = imageBytes clean = imageBytes
default:
return nil, errors.New("media type unrecognized")
} }
original, err := deriveImage(clean, contentType) original, err := deriveImage(clean, contentType)
@ -144,12 +149,15 @@ func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string
} }
// now put it in storage, take a new uuid for the name of the file so we don't store any unnecessary info about it // now put it in storage, take a new uuid for the name of the file so we don't store any unnecessary info about it
extension := strings.Split(contentType, "/")[1]
newMediaID := uuid.NewString() newMediaID := uuid.NewString()
originalPath := fmt.Sprintf("/%s/media/%s/original/%s.%s", accountID, headerOrAvi, newMediaID, contentType) // we store the original...
originalPath := fmt.Sprintf("%s/media/%s/original/%s.%s", accountID, headerOrAvi, newMediaID, extension)
if err := mh.storage.StoreFileAt(originalPath, original.image); err != nil { if err := mh.storage.StoreFileAt(originalPath, original.image); err != nil {
return nil, fmt.Errorf("storage error: %s", err) return nil, fmt.Errorf("storage error: %s", err)
} }
smallPath := fmt.Sprintf("/%s/media/%s/small/%s.%s", accountID, headerOrAvi, newMediaID, contentType) // and a thumbnail...
smallPath := fmt.Sprintf("%s/media/%s/small/%s.%s", accountID, headerOrAvi, newMediaID, extension)
if err := mh.storage.StoreFileAt(smallPath, small.image); err != nil { if err := mh.storage.StoreFileAt(smallPath, small.image); err != nil {
return nil, fmt.Errorf("storage error: %s", err) return nil, fmt.Errorf("storage error: %s", err)
} }
@ -158,9 +166,11 @@ func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string
ID: newMediaID, ID: newMediaID,
StatusID: "", StatusID: "",
RemoteURL: "", RemoteURL: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Type: model.FileTypeImage, Type: model.FileTypeImage,
FileMeta: model.ImageFileMeta{ FileMeta: model.FileMeta{
Original: model.ImageOriginal{ Original: model.Original{
Width: original.width, Width: original.width,
Height: original.height, Height: original.height,
Size: original.size, Size: original.size,
@ -182,11 +192,13 @@ func (mh *mediaHandler) processHeaderOrAvi(imageBytes []byte, contentType string
Path: originalPath, Path: originalPath,
ContentType: contentType, ContentType: contentType,
FileSize: len(original.image), FileSize: len(original.image),
UpdatedAt: time.Now(),
}, },
Thumbnail: model.Thumbnail{ Thumbnail: model.Thumbnail{
Path: smallPath, Path: smallPath,
ContentType: contentType, ContentType: contentType,
FileSize: len(small.image), FileSize: len(small.image),
UpdatedAt: time.Now(),
RemoteURL: "", RemoteURL: "",
}, },
Avatar: isAvatar, Avatar: isAvatar,

View File

@ -20,9 +20,12 @@ package media
import ( import (
"context" "context"
"io/ioutil"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -36,7 +39,7 @@ type MediaTestSuite struct {
log *logrus.Logger log *logrus.Logger
db db.DB db db.DB
mediaHandler *mediaHandler mediaHandler *mediaHandler
mockStorage storage.Storage mockStorage *storage.MockStorage
} }
/* /*
@ -61,9 +64,10 @@ func (suite *MediaTestSuite) SetupSuite() {
Database: "postgres", Database: "postgres",
ApplicationName: "gotosocial", ApplicationName: "gotosocial",
} }
c.MediaConfig = &config.MediaConfig{
MaxImageSize: 2 << 20,
}
suite.config = c suite.config = c
suite.config.MediaConfig.MaxImageSize = 2 << 20 // 2 megabits
// use an actual database for this, because it's just easier than mocking one out // use an actual database for this, because it's just easier than mocking one out
database, err := db.New(context.Background(), c, log) database, err := db.New(context.Background(), c, log)
if err != nil { if err != nil {
@ -72,12 +76,14 @@ func (suite *MediaTestSuite) SetupSuite() {
suite.db = database suite.db = database
suite.mockStorage = &storage.MockStorage{} suite.mockStorage = &storage.MockStorage{}
// We don't need storage to do anything for these tests, so just simulate a success and do nothing
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
// and finally here's the thing we're actually testing! // and finally here's the thing we're actually testing!
suite.mediaHandler = &mediaHandler{ suite.mediaHandler = &mediaHandler{
config: suite.config, config: suite.config,
db: suite.db, db: suite.db,
storage: &storage.MockStorage{}, storage: suite.mockStorage,
log: log, log: log,
} }
@ -122,6 +128,23 @@ func (suite *MediaTestSuite) TearDownTest() {
ACTUAL TESTS ACTUAL TESTS
*/ */
func (suite *MediaTestSuite) TestSetHeaderOrAvatarForAccountID() {
// load test image
f, err := ioutil.ReadFile("./test/test-jpeg.jpg")
assert.Nil(suite.T(), err)
ma, err := suite.mediaHandler.SetHeaderOrAvatarForAccountID(f, "weeeeeee", "header")
assert.Nil(suite.T(), err)
suite.log.Debugf("%+v", ma)
// attachment should have....
assert.Equal(suite.T(), "weeeeeee", ma.AccountID)
assert.Equal(suite.T(), "LjCZnlvyRkRn_NvzRjWF?urqV@f9", ma.Blurhash)
//TODO: add more checks here, cba right now!
}
// TODO: add tests for sad path, gif, png....
func TestMediaTestSuite(t *testing.T) { func TestMediaTestSuite(t *testing.T) {
suite.Run(t, new(MediaTestSuite)) suite.Run(t, new(MediaTestSuite))
} }

View File

@ -3,8 +3,6 @@
package media package media
import ( import (
io "io"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
model "github.com/superseriousbusiness/gotosocial/internal/db/model" model "github.com/superseriousbusiness/gotosocial/internal/db/model"
) )
@ -14,13 +12,13 @@ type MockMediaHandler struct {
mock.Mock mock.Mock
} }
// SetHeaderOrAvatarForAccountID provides a mock function with given fields: f, accountID, headerOrAvi // SetHeaderOrAvatarForAccountID provides a mock function with given fields: img, accountID, headerOrAvi
func (_m *MockMediaHandler) SetHeaderOrAvatarForAccountID(f io.Reader, accountID string, headerOrAvi string) (*model.MediaAttachment, error) { func (_m *MockMediaHandler) SetHeaderOrAvatarForAccountID(img []byte, accountID string, headerOrAvi string) (*model.MediaAttachment, error) {
ret := _m.Called(f, accountID, headerOrAvi) ret := _m.Called(img, accountID, headerOrAvi)
var r0 *model.MediaAttachment var r0 *model.MediaAttachment
if rf, ok := ret.Get(0).(func(io.Reader, string, string) *model.MediaAttachment); ok { if rf, ok := ret.Get(0).(func([]byte, string, string) *model.MediaAttachment); ok {
r0 = rf(f, accountID, headerOrAvi) r0 = rf(img, accountID, headerOrAvi)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.MediaAttachment) r0 = ret.Get(0).(*model.MediaAttachment)
@ -28,8 +26,8 @@ func (_m *MockMediaHandler) SetHeaderOrAvatarForAccountID(f io.Reader, accountID
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func(io.Reader, string, string) error); ok { if rf, ok := ret.Get(1).(func([]byte, string, string) error); ok {
r1 = rf(f, accountID, headerOrAvi) r1 = rf(img, accountID, headerOrAvi)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }

View File

@ -26,7 +26,6 @@ import (
"image/gif" "image/gif"
"image/jpeg" "image/jpeg"
"image/png" "image/png"
"io"
"github.com/buckket/go-blurhash" "github.com/buckket/go-blurhash"
"github.com/h2non/filetype" "github.com/h2non/filetype"
@ -36,9 +35,9 @@ import (
// parseContentType parses the MIME content type from a file, returning it as a string in the form (eg., "image/jpeg"). // parseContentType parses the MIME content type from a file, returning it as a string in the form (eg., "image/jpeg").
// Returns an error if the content type is not something we can process. // Returns an error if the content type is not something we can process.
func parseContentType(f io.Reader) (string, error) { func parseContentType(content []byte) (string, error) {
head := make([]byte, 261) head := make([]byte, 261)
_, err := f.Read(head) _, err := bytes.NewReader(content).Read(head)
if err != nil { if err != nil {
return "", fmt.Errorf("could not read first magic bytes of file: %s", err) return "", fmt.Errorf("could not read first magic bytes of file: %s", err)
} }

View File

@ -20,7 +20,6 @@ package media
import ( import (
"io/ioutil" "io/ioutil"
"os"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -64,7 +63,7 @@ func (suite *MediaUtilTestSuite) TearDownTest() {
*/ */
func (suite *MediaUtilTestSuite) TestParseContentTypeOK() { func (suite *MediaUtilTestSuite) TestParseContentTypeOK() {
f, err := os.Open("./test/test-jpeg.jpg") f, err := ioutil.ReadFile("./test/test-jpeg.jpg")
assert.Nil(suite.T(), err) assert.Nil(suite.T(), err)
ct, err := parseContentType(f) ct, err := parseContentType(f)
assert.Nil(suite.T(), err) assert.Nil(suite.T(), err)
@ -72,7 +71,7 @@ func (suite *MediaUtilTestSuite) TestParseContentTypeOK() {
} }
func (suite *MediaUtilTestSuite) TestParseContentTypeNotOK() { func (suite *MediaUtilTestSuite) TestParseContentTypeNotOK() {
f, err := os.Open("./test/test-corrupted.jpg") f, err := ioutil.ReadFile("./test/test-corrupted.jpg")
assert.Nil(suite.T(), err) assert.Nil(suite.T(), err)
ct, err := parseContentType(f) ct, err := parseContentType(f)
assert.NotNil(suite.T(), err) assert.NotNil(suite.T(), err)
@ -135,6 +134,14 @@ func (suite *MediaUtilTestSuite) TestDeriveThumbnailFromJPEG() {
assert.EqualValues(suite.T(), sampleBytes, imageAndMeta.image) assert.EqualValues(suite.T(), sampleBytes, imageAndMeta.image)
} }
func (suite *MediaUtilTestSuite) TestSupportedImageTypes() {
ok := supportedImageType("image/jpeg")
assert.True(suite.T(), ok)
ok = supportedImageType("image/bmp")
assert.False(suite.T(), ok)
}
func TestMediaUtilTestSuite(t *testing.T) { func TestMediaUtilTestSuite(t *testing.T) {
suite.Run(t, new(MediaUtilTestSuite)) suite.Run(t, new(MediaUtilTestSuite))
} }

View File

@ -161,7 +161,7 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
return return
} }
// TODO: form validation // TODO: proper form validation
// TODO: tidy this code into subfunctions // TODO: tidy this code into subfunctions
if form.Header != nil && form.Header.Size != 0 { if form.Header != nil && form.Header.Size != 0 {
@ -177,7 +177,23 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)})
return return
} }
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(f, authed.Account.ID, "header")
// extract the bytes
imageBytes := []byte{}
size, err := f.Read(imageBytes)
defer func(){
if err := f.Close(); err != nil {
m.log.Errorf("error closing multipart file: %s", err)
}
}()
if err != nil || size == 0 {
l.Debugf("error processing header: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)})
return
}
// do the setting
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(imageBytes, authed.Account.ID, "header")
if err != nil { if err != nil {
l.Debugf("error processing header: %s", err) l.Debugf("error processing header: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})