init
This commit is contained in:
214
ca-server/vendor/github.com/go-sql-driver/mysql/compress.go
generated
vendored
Normal file
214
ca-server/vendor/github.com/go-sql-driver/mysql/compress.go
generated
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
zrPool *sync.Pool // Do not use directly. Use zDecompress() instead.
|
||||
zwPool *sync.Pool // Do not use directly. Use zCompress() instead.
|
||||
)
|
||||
|
||||
func init() {
|
||||
zrPool = &sync.Pool{
|
||||
New: func() any { return nil },
|
||||
}
|
||||
zwPool = &sync.Pool{
|
||||
New: func() any {
|
||||
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
|
||||
if err != nil {
|
||||
panic(err) // compress/zlib return non-nil error only if level is invalid
|
||||
}
|
||||
return zw
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func zDecompress(src []byte, dst *bytes.Buffer) (int, error) {
|
||||
br := bytes.NewReader(src)
|
||||
var zr io.ReadCloser
|
||||
var err error
|
||||
|
||||
if a := zrPool.Get(); a == nil {
|
||||
if zr, err = zlib.NewReader(br); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
zr = a.(io.ReadCloser)
|
||||
if err := zr.(zlib.Resetter).Reset(br, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again.
|
||||
err = zr.Close() // zr.Close() may return chuecksum error.
|
||||
zrPool.Put(zr)
|
||||
return int(n), err
|
||||
}
|
||||
|
||||
func zCompress(src []byte, dst io.Writer) error {
|
||||
zw := zwPool.Get().(*zlib.Writer)
|
||||
zw.Reset(dst)
|
||||
if _, err := zw.Write(src); err != nil {
|
||||
return err
|
||||
}
|
||||
err := zw.Close()
|
||||
zwPool.Put(zw)
|
||||
return err
|
||||
}
|
||||
|
||||
type compIO struct {
|
||||
mc *mysqlConn
|
||||
buff bytes.Buffer
|
||||
}
|
||||
|
||||
func newCompIO(mc *mysqlConn) *compIO {
|
||||
return &compIO{
|
||||
mc: mc,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compIO) reset() {
|
||||
c.buff.Reset()
|
||||
}
|
||||
|
||||
func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) {
|
||||
for c.buff.Len() < need {
|
||||
if err := c.readCompressedPacket(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
data := c.buff.Next(need)
|
||||
return data[:need:need], nil // prevent caller writes into c.buff
|
||||
}
|
||||
|
||||
func (c *compIO) readCompressedPacket(r readerFunc) error {
|
||||
header, err := c.mc.buf.readNext(7, r) // size of compressed header
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = header[6] // bounds check hint to compiler; guaranteed by readNext
|
||||
|
||||
// compressed header structure
|
||||
comprLength := getUint24(header[0:3])
|
||||
compressionSequence := uint8(header[3])
|
||||
uncompressedLength := getUint24(header[4:7])
|
||||
if debug {
|
||||
fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
|
||||
comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
|
||||
}
|
||||
// Do not return ErrPktSync here.
|
||||
// Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
|
||||
// before receiving all packets from client. In this case, seqnr is younger than expected.
|
||||
// NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
|
||||
if debug && compressionSequence != c.mc.sequence {
|
||||
fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
|
||||
c.mc.sequence, compressionSequence)
|
||||
}
|
||||
c.mc.sequence = compressionSequence + 1
|
||||
c.mc.compressSequence = c.mc.sequence
|
||||
|
||||
comprData, err := c.mc.buf.readNext(comprLength, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if payload is uncompressed, its length will be specified as zero, and its
|
||||
// true length is contained in comprLength
|
||||
if uncompressedLength == 0 {
|
||||
c.buff.Write(comprData)
|
||||
return nil
|
||||
}
|
||||
|
||||
// use existing capacity in bytesBuf if possible
|
||||
c.buff.Grow(uncompressedLength)
|
||||
nread, err := zDecompress(comprData, &c.buff)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nread != uncompressedLength {
|
||||
return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
|
||||
uncompressedLength, nread)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const minCompressLength = 150
|
||||
const maxPayloadLen = maxPacketSize - 4
|
||||
|
||||
// writePackets sends one or some packets with compression.
|
||||
// Use this instead of mc.netConn.Write() when mc.compress is true.
|
||||
func (c *compIO) writePackets(packets []byte) (int, error) {
|
||||
totalBytes := len(packets)
|
||||
blankHeader := make([]byte, 7)
|
||||
buf := &c.buff
|
||||
|
||||
for len(packets) > 0 {
|
||||
payloadLen := min(maxPayloadLen, len(packets))
|
||||
payload := packets[:payloadLen]
|
||||
uncompressedLen := payloadLen
|
||||
|
||||
buf.Reset()
|
||||
buf.Write(blankHeader) // Buffer.Write() never returns error
|
||||
|
||||
// If payload is less than minCompressLength, don't compress.
|
||||
if uncompressedLen < minCompressLength {
|
||||
buf.Write(payload)
|
||||
uncompressedLen = 0
|
||||
} else {
|
||||
err := zCompress(payload, buf)
|
||||
if debug && err != nil {
|
||||
fmt.Printf("zCompress error: %v", err)
|
||||
}
|
||||
// do not compress if compressed data is larger than uncompressed data
|
||||
// I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes.
|
||||
if err != nil || buf.Len() >= uncompressedLen {
|
||||
buf.Reset()
|
||||
buf.Write(blankHeader)
|
||||
buf.Write(payload)
|
||||
uncompressedLen = 0
|
||||
}
|
||||
}
|
||||
|
||||
if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
|
||||
// To allow returning ErrBadConn when sending really 0 bytes, we sum
|
||||
// up compressed bytes that is returned by underlying Write().
|
||||
return totalBytes - len(packets) + n, err
|
||||
}
|
||||
packets = packets[payloadLen:]
|
||||
}
|
||||
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
||||
// writeCompressedPacket writes a compressed packet with header.
|
||||
// data should start with 7 size space for header followed by payload.
|
||||
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) {
|
||||
mc := c.mc
|
||||
comprLength := len(data) - 7
|
||||
if debug {
|
||||
fmt.Printf(
|
||||
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
|
||||
comprLength, uncompressedLen, mc.compressSequence)
|
||||
}
|
||||
|
||||
// compression header
|
||||
putUint24(data[0:3], comprLength)
|
||||
data[3] = mc.compressSequence
|
||||
putUint24(data[4:7], uncompressedLen)
|
||||
|
||||
mc.compressSequence++
|
||||
return mc.writeWithTimeout(data)
|
||||
}
|
||||
Reference in New Issue
Block a user