diff --git a/omod/src/main/java/org/openmrs/module/attachments/rest/AttachmentResource.java b/omod/src/main/java/org/openmrs/module/attachments/rest/AttachmentResource.java index 3097b3f..da3f9f1 100644 --- a/omod/src/main/java/org/openmrs/module/attachments/rest/AttachmentResource.java +++ b/omod/src/main/java/org/openmrs/module/attachments/rest/AttachmentResource.java @@ -2,26 +2,21 @@ import static org.openmrs.module.attachments.AttachmentsContext.getContentFamily; +import java.io.ByteArrayInputStream; import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; -import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import io.swagger.models.Model; import io.swagger.models.ModelImpl; -import io.swagger.models.properties.ByteArrayProperty; import io.swagger.models.properties.DateProperty; -import io.swagger.models.properties.RefProperty; import io.swagger.models.properties.StringProperty; import org.apache.commons.codec.binary.Base64; -import org.apache.commons.lang.ArrayUtils; -import org.apache.commons.lang.BooleanUtils; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.BooleanUtils; +import org.apache.commons.lang3.StringUtils; import org.openmrs.Encounter; import org.openmrs.Obs; import org.openmrs.Patient; @@ -80,7 +75,6 @@ public Attachment getByUniqueId(String uniqueId) { if (!obs.isComplex()) throw new GenericRestException(uniqueId + " does not identify a complex obs.", null); else { - obs = Context.getObsService().getComplexObs(obs.getId(), AttachmentsConstants.ATT_VIEW_CRUD); return new Attachment(obs, Context.getRegisteredComponent(AttachmentsConstants.COMPONENT_ATT_CONTEXT, AttachmentsContext.class) .getComplexDataHelper()); @@ -116,7 +110,7 @@ public Object upload(MultipartFile file, RequestContext context) throws Response AttachmentsContext.class); if (base64Content != null) { - file = new Base64MultipartFile(base64Content); + file = new Base64MultipartFile(base64Content, file.getName(), file.getOriginalFilename()); } // Verify File Size if (ctx.getMaxUploadFileSize() * 1024 * 1024 < (double) file.getSize()) { @@ -127,14 +121,17 @@ public Object upload(MultipartFile file, RequestContext context) throws Response String fileName = file.getOriginalFilename(); int idx = fileName.lastIndexOf("."); String fileExtension = idx > 0 && idx < fileName.length() - 1 ? fileName.substring(idx + 1) : ""; - if (!ArrayUtils.isEmpty(ctx.getAllowedFileExtensions()) && !Arrays.stream(ctx.getAllowedFileExtensions()) - .filter(e -> e.equalsIgnoreCase(fileExtension)).findAny().isPresent()) { - throw new IllegalRequestException("The extension is not valid"); + + String[] allowedExtensions = ctx.getAllowedFileExtensions(); + if (allowedExtensions != null && allowedExtensions.length > 0 && Arrays.stream(allowedExtensions) + .filter(s -> s != null && !s.isEmpty()).noneMatch(fileExtension::equalsIgnoreCase)) { + throw new IllegalRequestException("The extension " + fileExtension + " is not valid"); } // Verify file name - if (!ArrayUtils.isEmpty(ctx.getDeniedFileNames()) - && Arrays.stream(ctx.getDeniedFileNames()).filter(e -> e.equalsIgnoreCase(fileName)).findAny().isPresent()) { + String[] deniedFileNames = ctx.getDeniedFileNames(); + if (deniedFileNames != null && deniedFileNames.length > 0 && Arrays.stream(deniedFileNames) + .filter(s -> s != null && !s.isEmpty()).anyMatch(fileName::equalsIgnoreCase)) { throw new IllegalRequestException("The file name is not valid"); } @@ -326,34 +323,32 @@ protected PageableResult doSearch(RequestContext context) { * dependency to MockMultipartFile for converting the base64 encoded String to a MultipartFile * object. */ - class Base64MultipartFile implements MultipartFile { + static final class Base64MultipartFile implements MultipartFile { + + private final String fileName; - private String fileName; + private final String originalFileName; - private String contentType; + private final String contentType; - private long size; + private final long size; - private InputStream in; + private final InputStream in; - private byte[] bytes; + private final byte[] bytes; - public Base64MultipartFile(String base64Image) throws IOException { - String[] parts = base64Image.split(","); + public Base64MultipartFile(String base64Image, String fileName, String originalFileName) throws IOException { + String[] parts = base64Image.split(",", 2); String contentType = parts[0].split(":")[1].split(";")[0].trim(); String contents = parts[1].trim(); byte[] decodedImage = Base64.decodeBase64(contents.getBytes()); - final File temp = File.createTempFile("cameracapture", ".png"); - try (OutputStream stream = new FileOutputStream(temp)) { - stream.write(decodedImage); - } - temp.deleteOnExit(); - this.fileName = temp.getName(); - this.in = new FileInputStream(temp); + this.fileName = fileName; + this.originalFileName = originalFileName; + this.in = new ByteArrayInputStream(decodedImage); this.contentType = contentType; this.bytes = decodedImage; - this.size = temp.length(); + this.size = decodedImage.length; } @Override @@ -363,7 +358,7 @@ public String getName() { @Override public String getOriginalFilename() { - return this.fileName; + return this.originalFileName; } @Override @@ -382,17 +377,17 @@ public long getSize() { } @Override - public byte[] getBytes() throws IOException { + public byte[] getBytes() { return this.bytes; } @Override - public InputStream getInputStream() throws IOException { + public InputStream getInputStream() { return this.in; } @Override - public void transferTo(File dest) throws IOException, IllegalStateException { + public void transferTo(File dest) throws IllegalStateException { throw new APIException("Operation transferTo is not supported for Base64MultipartFile"); } } diff --git a/omod/src/test/java/org/openmrs/module/attachments/rest/AttachmentRestControllerTest.java b/omod/src/test/java/org/openmrs/module/attachments/rest/AttachmentRestControllerTest.java index f30e113..136fc7b 100644 --- a/omod/src/test/java/org/openmrs/module/attachments/rest/AttachmentRestControllerTest.java +++ b/omod/src/test/java/org/openmrs/module/attachments/rest/AttachmentRestControllerTest.java @@ -12,6 +12,7 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.Objects; import java.util.Random; import javax.imageio.ImageIO; import javax.servlet.http.HttpServletRequest; @@ -256,8 +257,7 @@ public void postAttachment_shouldUploadFileToVisit() throws Exception { SimpleObject response = deserialize(handle(request)); fileName = "testFile1_" + (String) response.get("uuid") + ".dat"; Obs obs = Context.getObsService().getObsByUuid((String) response.get("uuid")); - Obs complexObs = Context.getObsService().getComplexObs(obs.getObsId(), null); - ComplexData complexData = complexObs.getComplexData(); + ComplexData complexData = obs.getComplexData(); // Verify Assert.assertEquals(obs.getComment(), fileCaption); @@ -286,10 +286,9 @@ public void postAttachment_shouldUploadFileAsEncounterless() throws Exception { // Replay SimpleObject response = deserialize(handle(request)); - Obs obs = Context.getObsService().getObsByUuid((String) response.get("uuid")); - fileName = "testFile2_" + (String) response.get("uuid") + ".dat"; - Obs complexObs = Context.getObsService().getComplexObs(obs.getObsId(), null); - ComplexData complexData = complexObs.getComplexData(); + Obs obs = Context.getObsService().getObsByUuid(response.get("uuid")); + fileName = "testFile2_" + response.get("uuid") + ".dat"; + ComplexData complexData = obs.getComplexData(); // Verify Assert.assertEquals(obs.getComment(), fileCaption); @@ -319,8 +318,7 @@ public void postAttachment_shouldUploadFileToEncounter() throws Exception { SimpleObject response = deserialize(handle(request)); fileName = "testFile3_" + (String) response.get("uuid") + ".dat"; Obs obs = Context.getObsService().getObsByUuid((String) response.get("uuid")); - Obs complexObs = Context.getObsService().getComplexObs(obs.getObsId(), null); - ComplexData complexData = complexObs.getComplexData(); + ComplexData complexData = obs.getComplexData(); // Verify Assert.assertEquals(obs.getComment(), fileCaption); @@ -332,15 +330,19 @@ public void postAttachment_shouldUploadFileToEncounter() throws Exception { @Test public void postAttachment_shouldAcceptBase64Content() throws Exception { // Read file OpenMRS_logo.png and copy bytes to baos - InputStream inputStream = getClass().getClassLoader().getResourceAsStream("OpenMRS_logo.png"); - BufferedImage img = ImageIO.read(inputStream); + BufferedImage img; + try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("OpenMRS_logo.png")) { + Objects.requireNonNull(inputStream); + img = ImageIO.read(inputStream); + } + ByteArrayOutputStream baos = new ByteArrayOutputStream(); ImageIO.write(img, "png", baos); // Build the request parameters byte[] bytesIn = baos.toByteArray(); String fileCaption = "Test file caption"; - String fileName = "testFile2.dat"; + String fileName = "testFile2.png"; String base64Content = "data:image/png;base64," + Base64.encodeBase64String(bytesIn); Patient patient = Context.getPatientService().getPatient(2); @@ -355,14 +357,13 @@ public void postAttachment_shouldAcceptBase64Content() throws Exception { // Replay SimpleObject response = deserialize(handle(request)); - Obs obs = Context.getObsService().getObsByUuid((String) response.get("uuid")); - Obs complexObs = Context.getObsService().getComplexObs(obs.getObsId(), null); - ComplexData complexData = complexObs.getComplexData(); + Obs obs = Context.getObsService().getObsByUuid(response.get("uuid")); + ComplexData complexData = obs.getComplexData(); byte[] bytesOut = BaseComplexData.getByteArray(complexData); // Verify Assert.assertEquals(obs.getComment(), fileCaption); - Assert.assertTrue(complexData.getTitle().startsWith("cameracapture")); + Assert.assertTrue(complexData.getTitle().startsWith("testFile2")); Assert.assertArrayEquals(bytesIn, bytesOut); Assert.assertNull(obs.getEncounter()); }