Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 183 additions & 61 deletions src/freq01.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import java.nio.charset.StandardCharsets

import java.io._
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.charset.StandardCharsets
import java.nio.file.Paths
import java.util

import cats.data._
import cats.data.{ NonEmptyList, ValidatedNel }
import cats.effect._
import cats.implicits._
import com.monovore.decline._
import com.monovore.decline.effect._
import fs2._
import it.unimi.dsi.fastutil.bytes.ByteArrayList
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.objects._
import it.unimi.dsi.fastutil.objects.ObjectArrays

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.annotation.tailrec

package freq01 {

object App extends CommandIOApp("freq", "Counts '[a-zA-Z]+' words in input", version = "0.1.0") {
private val a = 'a'.toByte
private val z = 'z'.toByte
Expand All @@ -30,25 +28,9 @@ package freq01 {
.use { blocker =>
args
.input[IO](blocker)
.through(words)
.fold(new Object2IntOpenHashMap[String]()) { (dict, word) =>
dict.put(word, dict.getOrDefault(word, 0) + 1)
dict
}
.through(collect)
.flatMap { dict =>
val ww = new ObjectArrayList[(Int, String)](dict.size())

dict.object2IntEntrySet().fastForEach { e =>
val _ = ww.add(e.getIntValue -> e.getKey)
}
ww.unstableSort({
case ((c1: Int, w1: String), (c2: Int, w2: String)) =>
val ints = c2 - c1
if (ints == 0) w1.compareTo(w2)
else ints
})

Stream.fromIterator[IO](ww.iterator().asScala)
Stream.fromIterator[IO](dict.drain)
}
.map {
case (counter, word) => f"$counter%d $word%s%n"
Expand All @@ -60,58 +42,69 @@ package freq01 {
.as(ExitCode.Success)
}

def words[F[_]]: Pipe[F, Byte, String] = { bytes =>
def bal2String(bytes: ByteArrayList): String = {
val arr = new Array[Byte](bytes.size())
bytes.toArray(arr)
bytes.clear()

new String(arr, StandardCharsets.UTF_8)
}
def collect[F[_]]: Pipe[F, ByteBuffer, FrequencyDict] = { buffers =>
def loop(
s: Stream[F, ByteBuffer],
dict: FrequencyDict,
lastHash: Int,
lastWord: ByteArrayList
): Pull[F, FrequencyDict, Unit] =
s.pull.uncons1.flatMap {
case None =>
if (lastWord.isEmpty) Pull.output1(dict)
else Pull.output1(dict.register(lastHash, lastWord.toArray(new Array[Byte](lastWord.size()))))
case Some(buffer -> nxt) =>
var hash = lastHash
val word = lastWord

def loop(s: Stream[F, Byte], rem: ByteArrayList): Pull[F, String, Unit] =
s.pull.uncons.flatMap {
case None if rem.isEmpty => Pull.done
case None => Pull.output1(bal2String(rem))
case Some(chunk -> s) =>
val words = mutable.ArrayBuilder.make[String]
val word = rem
chunk.foreach { byte =>
while (buffer.remaining() > 0) {
var byte = buffer.get()
if (a <= byte && byte <= z) {
val _ = word.add(byte)
hash = Fnv1.next(hash, byte)
word.add(byte)
} else if (A <= byte && byte <= Z) {
val _ = word.add((byte ^ 0x020).toByte)
byte = (byte | 0x20).toByte
hash = Fnv1.next(hash, byte)
word.add(byte)
} else if (!word.isEmpty) {
words += bal2String(word)
dict.register(hash, word.toArray(new Array[Byte](word.size())))
word.clear()
hash = Fnv1.H
}
}

Pull.output(Chunk.seq(words.result())) >> loop(s, word)
loop(nxt, dict, hash, word)
}

loop(bytes, new ByteArrayList(16)).stream
loop(buffers, FrequencyDict(), Fnv1.H, new ByteArrayList(256)).stream
}

}

final case class Args(in: Option[File], out: Option[File], chunkSize: Int, bufferSize: Option[Int]) {

def input[F[_]](blocker: Blocker)(implicit F: Sync[F], CS: ContextShift[F]): Stream[F, Byte] =
def input[F[_]](blocker: Blocker)(implicit F: Sync[F], CS: ContextShift[F]): Stream[F, ByteBuffer] =
in.map { file =>
val fis = F
.catchNonFatal {
bufferSize
.map { size =>
new FastBufferedInputStream(new FileInputStream(file), size)
}
.getOrElse {
new FastBufferedInputStream(new FileInputStream(file))
}
Stream
.bracket {
blocker.delay(new FileInputStream(file).getChannel)
} { channel =>
blocker.delay(channel.close())
}
.flatMap { channel =>
Stream.unfoldEval(0L -> math.min(channel.size(), Int.MaxValue.toLong)) {
case (_, 0L) => F.pure(none[(ByteBuffer, (Long, Long))])
case (p, sz) =>
val size = math.min(sz, Int.MaxValue.toLong)
blocker
.delay(channel.map(FileChannel.MapMode.READ_ONLY, p, size))
.widen[ByteBuffer]
.tupleRight((p + size) -> (channel.size() - p - size))
.map(_.some)
}
}
.widen[InputStream]
io.readInputStream(fis, chunkSize, blocker)
}
.getOrElse {
io.stdin(chunkSize, blocker)
io.stdin(chunkSize, blocker).chunks.map(_.toByteBuffer)
}

def output[F[_]](blocker: Blocker)(implicit F: Sync[F], CS: ContextShift[F]): Pipe[F, String, Unit] = { lines =>
Expand Down Expand Up @@ -181,4 +174,133 @@ package freq01 {
val parse: Opts[Args] = (in, out, chunkSize, bufferSize).mapN(Args(_, _, _, _))
}

/** Incremental 32-bit FNV hashing, one octet at a time.
  *
  * `H` is the standard 32-bit FNV offset basis and `P` the 32-bit FNV prime.
  * Note that [[next]] XORs the octet in before multiplying, which is the
  * FNV-1a ordering, even though the object is named `Fnv1`.
  */
object Fnv1 {

  /** 32-bit FNV offset basis: the hash of the empty input. */
  val H: Int = 0x811c9dc5

  /** 32-bit FNV prime. */
  val P: Int = 0x01000193

  /** Folds one more octet into a running hash.
    *
    * The byte is masked to its unsigned value before the XOR. Without the
    * mask a negative `Byte` is sign-extended to `Int` and flips the upper
    * 24 bits of the hash. Results for bytes in `0x00..0x7f` are unaffected.
    *
    * @param hash  hash of the octets seen so far (`H` for none)
    * @param value next octet of the input
    * @return the hash extended by `value`
    */
  def next(hash: Int, value: Byte): Int = (hash ^ (value & 0xff)) * P
}

/** Mutable occurrence counter for byte-string keys, implemented as an
  * open-addressing hash table with linear probing.
  *
  * Hashes and values live in two parallel arrays. A stored hash of `0` marks
  * an empty slot, so an incoming hash of `0` is remapped to `Fnv1.H` before it
  * is stored. The table doubles once the number of keys exceeds
  * `LoadFactor * capacity`, so there is always at least one empty slot and
  * probing terminates. The class performs no synchronization.
  *
  * @param initial requested starting capacity; rounded up to a power of two
  *                (at least 1, at most 2^30) because slot indices are
  *                computed with a bit mask
  */
final class FrequencyDict(initial: Int) {
  import FrequencyDict.{ LoadFactor, Value }

  // Largest power of two that still fits in an Int array length.
  private final val MaxCapacity = 1 << 30

  // `idx & mask` only reaches every slot when the capacity is a power of two;
  // any other size would leave slots unreachable and could probe forever.
  private val base: Int = {
    val wanted = math.max(initial, 1)
    val floor  = Integer.highestOneBit(wanted)
    if (floor == wanted || floor == MaxCapacity) floor else floor << 1
  }

  private var capacity = base
  private var length   = 0
  private var mask     = capacity - 1
  private var max      = threshold(capacity)

  private var hashes = new Array[Int](capacity)
  private var values = new Array[Value](capacity)

  /** Number of keys a table of `cap` slots may hold before it has to grow. */
  private def threshold(cap: Int): Int = (LoadFactor * cap.toFloat).toInt

  /** Number of distinct keys currently stored. */
  def size: Int = length

  /** Empties the dictionary and returns its former content as `(count, word)`
    * pairs in the order defined by `Value.compareTo`.
    *
    * The table is reset to its starting capacity before the old value array
    * is sorted, so the dictionary is immediately reusable while the returned
    * iterator is consumed.
    */
  def drain: Iterator[(Int, String)] = {
    val data = values

    hashes = new Array[Int](base)
    values = new Array[Value](base)

    capacity = base
    length = 0
    mask = capacity - 1
    max = threshold(capacity)

    // Empty slots (nulls) are ordered after every real value, so the occupied
    // entries end up as a contiguous prefix of the sorted array.
    // scalafix:off DisableSyntax.null; keeping buckets sparse for locality
    ObjectArrays.unstableSort(data, { (l: Value, r: Value) =>
      if ((l ne null) && (r ne null)) l.compareTo(r)
      else if (l ne null) -1
      else if (r ne null) 1
      else 0
    })
    //scalafix:on

    data.iterator
      .takeWhile(_ ne null) // scalafix:ok DisableSyntax.null
      .map(v => (v.value, v.key))
  }

  /** Counts one occurrence of `key`.
    *
    * @param hash hash of `key` as computed by the caller; `0` is remapped
    *             because `0` is reserved for empty slots
    * @param key  the key bytes; the array is retained when the key is new,
    *             so the caller must not modify it afterwards
    * @return this dictionary, to allow chaining
    */
  def register(hash: Int, key: Array[Byte]): this.type = {
    val hsh = if (hash == 0) Fnv1.H else hash
    @tailrec def loop(idx: Int): Unit = {
      val idxHash = hashes(idx)
      if (idxHash == 0) {
        // Empty slot: first occurrence of this key.
        hashes(idx) = hsh
        values(idx) = Value(key)
        length += 1

        if (length > max) ensureCapacity()
      } else if (idxHash != hsh || !values(idx).update(key)) {
        // Different hash, or same hash but different bytes: keep probing.
        // `update` has already bumped the counter when the bytes matched.
        loop((idx + 1) & mask)
      }
    }

    loop(hsh & mask)
    this
  }

  /** Grows the table (by doubling) until `length` fits under the load
    * threshold and re-inserts every stored entry.
    *
    * The new size is computed and the new arrays are filled before any field
    * is changed, so a failure while growing leaves the dictionary intact.
    */
  private def ensureCapacity(): Unit = {
    var grown = capacity
    while (length > threshold(grown)) {
      // Doubling past 2^30 would overflow Int and the loop would never end.
      if (grown >= MaxCapacity)
        throw new IllegalStateException(s"FrequencyDict cannot grow beyond $MaxCapacity slots")
      grown <<= 1
    }

    if (grown != capacity) {
      val newMask   = grown - 1
      val newHashes = new Array[Int](grown)
      val newValues = new Array[Value](grown)

      var i = 0
      while (i < hashes.length) {
        val hash = hashes(i)
        if (hash != 0) {
          var idx = hash & newMask
          while (newHashes(idx) != 0) idx = (idx + 1) & newMask

          newHashes(idx) = hash
          newValues(idx) = values(i)
        }
        i += 1
      }

      hashes = newHashes
      values = newValues
      capacity = grown
      mask = newMask
      max = threshold(grown)
    }
  }

}

/** Companion of the `FrequencyDict` class: its slot type, tuning constants and factory. */
object FrequencyDict {

  /** One dictionary entry: the key bytes plus a running occurrence count.
    *
    * A freshly created entry has a count of 1. The natural ordering puts
    * higher counts first and breaks ties by comparing the key bytes.
    */
  final class Value(private val bytes: Array[Byte]) extends Comparable[Value] {
    private var counter: Int = 1

    /** The key decoded as UTF-8; decoded on first access only. */
    lazy val key: String = new String(bytes, StandardCharsets.UTF_8)

    /** How many times the key has been counted so far. */
    def value: Int = counter

    /** Increments the count if `arr` holds exactly this entry's key bytes.
      *
      * @return `true` when the bytes matched (and the count was bumped),
      *         `false` otherwise
      */
    def update(arr: Array[Byte]): Boolean =
      if (util.Arrays.equals(bytes, arr)) {
        counter += 1
        true
      } else false

    /** Orders by descending count, then by key bytes (compared as signed
      * bytes); when one key is a prefix of the other the shorter sorts first.
      */
    def compareTo(that: Value): Int = {
      val byCount = that.counter - this.counter
      if (byCount != 0) byCount
      else {
        val mine   = this.bytes
        val theirs = that.bytes
        val shared = math.min(mine.length, theirs.length)

        var pos  = 0
        var diff = 0
        while (diff == 0 && pos < shared) {
          diff = mine(pos) - theirs(pos)
          pos += 1
        }

        if (diff != 0) diff
        else if (mine.length < theirs.length) -1
        else if (mine.length > theirs.length) 1
        else 0
      }
    }
  }

  object Value {

    /** Creates an entry for `bytes` with a count of 1. */
    def apply(bytes: Array[Byte]): Value =
      new Value(bytes)
  }

  /** Slot count used by [[apply]]. */
  val InitialCapacity: Int = 128

  /** Fraction of the slots that may be occupied before the table grows. */
  val LoadFactor: Float = 0.9f

  /** Creates an empty dictionary with `InitialCapacity` slots. */
  def apply(): FrequencyDict = new FrequencyDict(InitialCapacity)

}

}